mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ from abc import ABC, abstractmethod
18
+
19
+ from msprobe.core.config_check.utils.utils import config_checking_print
20
+ from msprobe.core.common.file_utils import FileOpen, load_yaml
21
+ from msprobe.core.common.const import Const, FileCheckConst
22
+
23
+
24
class Parser(ABC):
    """Abstract base class for hyperparameter-file parsers.

    Subclasses implement :meth:`parse`; callers go through :meth:`run`, which
    converts any parsing failure into a log message plus an empty result.
    """

    @abstractmethod
    def parse(self, file_path: str) -> dict:
        """Parse ``file_path`` and return the extracted hyperparameters."""

    def run(self, file_path: str) -> dict:
        """Unified public entry point.

        :param file_path: path of the file to parse
        :return: whatever :meth:`parse` returns, or ``{}`` if it raised
        """
        try:
            return self.parse(file_path)
        except Exception as exc:
            # Best-effort by design: one unparsable file must not abort the whole check.
            config_checking_print(f"{self.__class__} parsing error, skip file path: {file_path}, error: {exc}")
            return {}
41
+
42
+
43
class ShellParser(Parser):
    """Parser for bash launch scripts (msrun / torchrun / torch.distributed.launch)."""

    def parse(self, file_path: str) -> dict:
        """
        Extracts arguments from bash script used to run a model training.

        :param file_path: path of the shell script
        :return: dict mapping each ``--flag`` name found on the launch command to its
                 value (``True`` for flags given without a value); ``{}`` when no
                 launch command is found
        """
        hyperparameters = {}
        script_content_list = []
        with FileOpen(file_path, 'r') as file:
            for line in file:
                stripped_line = line.lstrip()
                if not stripped_line.startswith('#'):
                    # Drop trailing comments and keep only non-blank lines.
                    line = line.split('#')[0].rstrip() + '\n'
                    if line.strip():
                        script_content_list.append(line)
        script_content = ''.join(script_content_list)

        command_line = re.search(r'msrun\s[^|]*|torchrun\s[^|]*|python\d? -m torch.distributed.launch\s[^|]*',
                                 script_content,
                                 re.DOTALL)
        if not command_line:
            return hyperparameters
        command_line = command_line.group()

        # Inline every XXX_ARGS="..." block referenced as $XXX_ARGS or ${XXX_ARGS}.
        blocks = re.findall(r'([a-zA-Z0-9_]{1,20}_ARGS)="(.*?)"', script_content, re.DOTALL)
        for block_name, block_content in blocks:
            flat_content = block_content.replace('\n', ' ')
            command_line = command_line.replace(f"${{{block_name}}}", flat_content)
            command_line = command_line.replace(f"${block_name}", flat_content)

        matches = re.findall(r'--([\w-]+)(?:\s+([^\s\\]+))?', command_line)
        for key, value in matches:
            args_key = re.match(r'\$\{?(\w+)}?', value)
            if args_key:
                # Resolve $VAR / ${VAR} from its last assignment in the script. The
                # look-behind stops e.g. $LR from matching MIN_LR=..., and [ \t]* keeps
                # an empty assignment from swallowing the following line.
                var_name = re.escape(args_key.group(1))
                env_vars = re.findall(rf'(?<!\w){var_name}=[ \t]*(.+)', script_content)
                if env_vars:
                    value = env_vars[-1]
            hyperparameters[key] = value if value else True

        return hyperparameters
83
+
84
+
85
class YamlParser(Parser):
    """Parser for YAML config files; nested keys are flattened and joined with ``Const.SEP``."""

    def __init__(self):
        # Per-instance result store. This used to be a class attribute, so every
        # YamlParser (including the single instance held by ParserFactory) kept
        # accumulating keys from all previously parsed files.
        self.hyperparameters = {}

    def parse(self, file_path: str) -> dict:
        """Load ``file_path`` as YAML and return its flattened scalar leaves.

        A fresh dict is created for every call, so results from earlier files are
        neither mixed in nor mutated afterwards.
        """
        self.hyperparameters = {}
        ori_hyper = load_yaml(file_path)
        self.recursive_parse_parameters(ori_hyper, "")
        return self.hyperparameters

    def recursive_parse_parameters(self, parameters, prefix):
        """Walk ``parameters`` and record scalar leaves under their joined key path.

        :param parameters: a dict, list or scalar loaded from YAML
        :param prefix: key path accumulated so far ("" at the top level)

        List items are recorded under the list's own prefix, so a later item
        overwrites an earlier one with the same key path.
        """
        if isinstance(parameters, dict):
            for key, value in parameters.items():
                # str()/f-string so non-string YAML keys (e.g. ints) do not raise.
                new_prefix = f"{prefix}{Const.SEP}{key}" if prefix else str(key)
                self.recursive_parse_parameters(value, new_prefix)
        elif isinstance(parameters, list):
            for value in parameters:
                self.recursive_parse_parameters(value, prefix)
        elif isinstance(parameters, (int, float, str, bool)):
            # float included: learning rates and similar values were previously dropped.
            self.hyperparameters.update({prefix: parameters})
103
+
104
+
105
class ParserFactory:
    """Maps a file suffix to the parser instance that handles it.

    The parser instances are created once at class definition and shared by every
    ``ParserFactory``.
    """

    __ParserDict = {
        FileCheckConst.SHELL_SUFFIX: ShellParser(),
        FileCheckConst.YAML_SUFFIX: YamlParser()
    }

    def get_parser(self, file_type: str) -> Parser:
        """Return the parser registered for ``file_type``.

        :raises ValueError: if no parser is registered for ``file_type``
        """
        registered = self.__ParserDict
        if file_type not in registered:
            raise ValueError(f'Invalid parser type: {file_type}')
        return registered[file_type]
@@ -0,0 +1,107 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import re
18
+ import hashlib
19
+
20
+ from msprobe.core.common.framework_adapter import FmkAdp
21
+ from msprobe.core.common.log import logger
22
+
23
+
24
def merge_keys(dir_0, dir_1):
    """Return the set of keys present in either of the two mappings."""
    return set(dir_0.keys()).union(dir_1.keys())
28
+
29
+
30
def compare_dict(bench_dict, cmp_dict):
    """Describe key-level differences between two flat dicts.

    :param bench_dict: baseline mapping
    :param cmp_dict: mapping compared against the baseline
    :return: list of strings, one per differing key:
             ``"key: old -> new"`` when both have the key with different values,
             ``"key: [deleted] -> old"`` when only ``bench_dict`` has it,
             ``"key: [added] -> new"`` when only ``cmp_dict`` has it.
             Order follows set iteration and is not guaranteed.
    """
    diffs = []
    for name in set(bench_dict.keys()) | set(cmp_dict.keys()):
        if name not in cmp_dict:
            diffs.append(f"{name}: [deleted] -> {bench_dict[name]}")
            continue
        if name not in bench_dict:
            diffs.append(f"{name}: [added] -> {cmp_dict[name]}")
            continue
        if bench_dict[name] != cmp_dict[name]:
            diffs.append(f"{name}: {bench_dict[name]} -> {cmp_dict[name]}")
    return diffs
41
+
42
+
43
def config_checking_print(msg):
    """Log ``msg`` at INFO level, tagged with the config-checking prefix."""
    tagged = f"[config checking log] {msg}"
    logger.info(tagged)
45
+
46
+
47
def tensor_to_hash(tensor):
    """Compute a short (16-bit) hash of a tensor's raw bytes.

    Assumes a torch-style tensor exposing ``detach``/``cpu``/``numpy`` (the calls
    made below). NumPy cannot represent some dtypes (e.g. ``torch.bfloat16``); such
    tensors are upcast with ``.float()`` before hashing instead of raising.
    """
    # detach().cpu() already yields a tensor safe to read; tobytes() copies the data,
    # so the former extra .clone() only cost a full duplicate of the tensor.
    detached = tensor.detach().cpu()
    try:
        array = detached.numpy()
    except TypeError:
        # e.g. "Got unsupported ScalarType BFloat16"; upcasting bf16 to float32 is lossless.
        array = detached.float().numpy()
    return bytes_hash(array.tobytes())
51
+
52
+
53
def get_tensor_features(tensor):
    """Return summary statistics of ``tensor`` as a dict keyed max/min/mean/norm.

    Each statistic is delegated to the corresponding ``FmkAdp`` helper.
    """
    stat_getters = (
        ("max", FmkAdp.tensor_max),
        ("min", FmkAdp.tensor_min),
        ("mean", FmkAdp.tensor_mean),
        ("norm", FmkAdp.tensor_norm),
    )
    return {stat_name: getter(tensor) for stat_name, getter in stat_getters}
62
+
63
+
64
def compare_dicts(dict1, dict2, path=''):
    """Recursively diff two nested dicts.

    :param dict1: baseline dict
    :param dict2: dict compared against the baseline
    :param path: prefix (``"a/b/"``) prepended to keys in the message lists; used by recursion
    :return: tuple ``(deleted, added, changed, result)``. The first three are lists of
             human-readable messages. ``result`` is a nested dict containing only the
             differing keys, marked ``"[deleted]"``, ``"[added]"`` or
             ``"[changed]: old -> new"``.
    """
    deleted = []
    added = []
    changed = []
    result = {}

    for key, value1 in dict1.items():
        # f-string instead of ``path + key`` so non-string keys (e.g. int YAML keys) work.
        key_path = f"{path}{key}"
        if key not in dict2:
            deleted.append(f"[Deleted]: {key_path}")
            result[key] = "[deleted]"
            continue
        value2 = dict2[key]
        if isinstance(value1, dict) and isinstance(value2, dict):
            sub_deleted, sub_added, sub_changed, sub_result = compare_dicts(
                value1, value2, key_path + '/')
            deleted.extend(sub_deleted)
            added.extend(sub_added)
            changed.extend(sub_changed)
            # Only keep sub-trees that actually contain differences.
            if sub_result:
                result[key] = sub_result
        elif value1 != value2:
            changed.append(f"[Changed]: {key_path} : {value1} -> {value2}")
            result[key] = f"[changed]: {value1} -> {value2}"
    for key in dict2:
        if key not in dict1:
            added.append(f"[Added]: {path}{key}")
            result[key] = "[added]"
    return deleted, added, changed, result
91
+
92
+
93
def bytes_hash(obj: bytes):
    """Return a 16-bit integer digest of ``obj``: SHA-256 reduced modulo 2**16."""
    digest = hashlib.sha256(obj).digest()
    # Reading the digest big-endian equals int(hexdigest, 16); keep its low 16 bits.
    return int.from_bytes(digest, byteorder="big") % (1 << 16)
97
+
98
+
99
def update_dict(ori_dict, new_dict):
    """Merge ``new_dict`` into ``ori_dict`` in place, recording conflicting values.

    Keys that are new, or whose value is unchanged, are simply set. On the first
    conflict the entry becomes ``{"description": "duplicate_value", "values": [old, new]}``.
    Later conflicting values are appended to that record's ``"values"`` list.

    :param ori_dict: dict updated in place
    :param new_dict: dict whose items are merged in
    """
    for key, value in new_dict.items():
        if key not in ori_dict or ori_dict[key] == value:
            ori_dict[key] = value
            continue
        existing = ori_dict[key]
        # Check the *entry* for an existing duplicate record. The old code tested the
        # outer dict ("values" in ori_dict), which nested records on the third value
        # and crashed if the outer dict happened to contain a "values" key.
        if isinstance(existing, dict) and existing.get("description") == "duplicate_value":
            existing["values"].append(value)
        else:
            ori_dict[key] = {"description": "duplicate_value", "values": [existing, value]}
@@ -0,0 +1,239 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from typing import Dict, Any, Optional, Callable, Union, List, Tuple
18
+
19
+ from msprobe.core.common.const import Const
20
+ from msprobe.core.common.file_utils import load_yaml
21
+ from msprobe.core.common.log import logger
22
+
23
+
24
def _get_attr(module, attr_name):
    """Look up ``attr_name`` on ``module``, splitting once at the last ``Const.SEP``.

    For a name containing ``Const.SEP`` the part before the last separator is fetched
    from ``module`` first, and the final part is fetched from that object.
    Returns None when the attribute or the intermediate object is missing.
    """
    owner, leaf = module, attr_name
    if Const.SEP in attr_name:
        parent_name, leaf = attr_name.rsplit(Const.SEP, 1)
        owner = getattr(module, parent_name, None)
    return getattr(owner, leaf, None)
32
+
33
+
34
class ApiWrapper:
    """Builds hooked wrapper functions for the APIs listed in support-list YAML files.

    ``api_types`` maps framework -> api_type -> sequence of modules. Element ``[0]`` of
    each module sequence is where original APIs are looked up.
    """

    def __init__(
        self, api_types: Dict[str, Dict[str, Any]],
        api_list_paths: Union[str, List[str], Tuple[str]],
        backlist: Optional[Union[List[str], Tuple[str]]] = None
    ):
        """
        :param api_types: framework -> api_type -> modules (index 0 is the lookup module)
        :param api_list_paths: one YAML path used for every framework, or one path per framework
        :param backlist: ``"<yaml key>.<api name>"`` entries to skip
        :raises RuntimeError: if a list/tuple of paths does not match the number of frameworks
        """
        self.api_types = api_types
        if not isinstance(api_list_paths, (list, tuple)):
            api_list_paths = [api_list_paths] * len(self.api_types)
        elif len(api_list_paths) != len(self.api_types):
            raise RuntimeError("The number of api_list_paths must be equal to the number of frameworks in 'api_types', "
                               "when api_list_paths is a list or tuple.")
        self.api_list_paths = api_list_paths
        self.backlist = backlist if backlist else []
        self.api_names = self._get_api_names()
        self.wrapped_api_functions = dict()

    @staticmethod
    def deal_with_self_kwargs(api_name, api_func, args, kwargs):
        """Move a keyword-passed ``self`` argument into the positional arguments.

        :return: ``(enable_wrap, args, kwargs)``. ``enable_wrap`` is False when the call
                 cannot be safely rearranged: the signature is unavailable, ``self`` is
                 keyword-only, or a required parameter before ``self`` was not supplied.
                 In that case ``args``/``kwargs`` are returned untouched.
        """
        if not kwargs or 'self' not in kwargs:
            return True, args, kwargs

        func_params = None
        try:
            func_params = inspect.signature(api_func).parameters
        except Exception:
            fallback = Const.API_WITH_SELF_ARG.get(api_name) if api_name in Const.API_WITH_SELF_ARG else None
            if fallback is not None:
                try:
                    func_params = inspect.signature(fallback).parameters
                except (TypeError, ValueError):
                    # Do not let signature introspection break the user's API call.
                    func_params = None
        if func_params is None:
            return False, args, kwargs

        param_items = list(func_params.items())
        if any(name == 'self' and param.kind == inspect.Parameter.KEYWORD_ONLY for name, param in param_items):
            return False, args, kwargs

        self_index = next((i for i, (name, _) in enumerate(param_items) if name == 'self'), None)
        if self_index is None:
            # 'self' is not a named parameter (e.g. it goes into **kwargs); nothing to
            # rearrange. Previously this padded args with inspect.Parameter.empty.
            return True, args, kwargs

        # Parameters that must become positional so that 'self' lands in its slot.
        to_fill = param_items[len(args):self_index + 1]
        if any(name not in kwargs and param.default is inspect.Parameter.empty for name, param in to_fill):
            # A required argument is missing; checked before mutating kwargs so the
            # unwrapped fallback call sees the caller's original arguments.
            return False, args, kwargs

        args_ = list(args)
        for name, param in to_fill:
            args_.append(kwargs.pop(name) if name in kwargs else param.default)
        return True, tuple(args_), kwargs

    def wrap_api(
        self, api_templates, hook_build_func: Optional[Callable]
    ):
        """Create hooked wrappers for every resolved API.

        :param api_templates: one template for all api_types, or one per api_type
                              (in ``api_types`` iteration order)
        :param hook_build_func: passed through to each template
        :return: framework -> api_type -> {api_name: wrapper}
        :raises RuntimeError: if a list/tuple of templates does not match the number of api_types
        """
        api_types_num = sum(len(v) for v in self.api_types.values())
        if not isinstance(api_templates, (list, tuple)):
            api_templates = [api_templates] * api_types_num
        elif len(api_templates) != api_types_num:
            raise RuntimeError("The number of api_templates must be equal to the number of api_types, "
                               "when api_templates is a list or tuple.")

        def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template):
            # Factory (rather than a closure in the loop) so each wrapper binds its own values.
            def api_function(*args, **kwargs):
                api_name_with_prefix = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1])
                enable_wrap, args, kwargs = self.deal_with_self_kwargs(api_name_with_prefix,
                                                                       api_func, args, kwargs)
                if not enable_wrap:
                    logger.warning(f'Cannot collect precision data of {api_name_with_prefix}. '
                                   'It may be fixed by passing the value of "self" '
                                   'as a positional argument instead of a keyword argument. ')
                    return api_func(*args, **kwargs)
                return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs)
            api_function.__name__ = api_name
            return api_function

        self.wrapped_api_functions.clear()
        index = 0
        for framework, api_types in self.api_types.items():
            wrapped_functions_in_framework = dict()
            for api_type, api_modules in api_types.items():
                wrapped_functions = dict()
                name_prefix = Const.API_DATA_PREFIX.get(framework, {}).get(api_type, "API")
                api_template = api_templates[index]
                index += 1
                for api_name in self.api_names.get(framework, {}).get(api_type, []):
                    ori_api = _get_attr(api_modules[0], api_name)
                    if callable(ori_api):
                        wrapped_functions[api_name] = wrap_api_func(api_name, ori_api, name_prefix,
                                                                    hook_build_func, api_template)
                wrapped_functions_in_framework[api_type] = wrapped_functions
            self.wrapped_api_functions[framework] = wrapped_functions_in_framework
        return self.wrapped_api_functions

    def _get_api_names(self):
        """Resolve, per framework and api_type, the API names present on the lookup module.

        Names come from the support-list YAML under the key given by
        ``Const.SUPPORT_API_DICT_KEY_MAP``. Entries ``"<key>.<name>"`` in the backlist are
        skipped, and a name is kept only if its (possibly dotted) attribute exists on
        ``api_modules[0]``.
        """
        api_names = dict()
        skipped = set(self.backlist)

        for index, framework in enumerate(self.api_types.keys()):
            # NOTE(review): guard in case load_yaml yields None (e.g. an empty file).
            api_list = load_yaml(self.api_list_paths[index]) or {}
            valid_names = dict()
            for api_type, api_modules in self.api_types.get(framework, {}).items():
                key_in_file = Const.SUPPORT_API_DICT_KEY_MAP.get(framework, {}).get(api_type)
                # "or []" also covers a key that is present with a null value.
                api_from_file = api_list.get(key_in_file, []) or []
                names = set()
                for api_name in api_from_file:
                    if f'{key_in_file}.{api_name}' in skipped:
                        continue
                    target_attr = api_name
                    target_module = api_modules[0]
                    if Const.SEP in api_name:
                        sub_module_name, target_attr = api_name.rsplit(Const.SEP, 1)
                        target_module = getattr(api_modules[0], sub_module_name, None)
                    if target_module and target_attr in dir(target_module):
                        names.add(api_name)
                valid_names[api_type] = names
            api_names[framework] = valid_names

        return api_names
148
+
149
+
150
class ApiRegistry:
    """
    Base class for api registry.

    Keeps the original and the hooked version of every supported API so they can be
    set onto the target modules (``register_*``) and put back (``restore_*``).
    """

    def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates, backlist=None):
        self.ori_api_attr = dict()
        self.wrapped_api_attr = dict()
        self.inner_used_ori_attr = dict()
        self.inner_used_wrapped_attr = dict()
        self.api_types = api_types
        self.inner_used_api = inner_used_api
        self.supported_api_list_path = supported_api_list_path
        self.api_templates = api_templates
        self.backlist = backlist if backlist else []
        self.all_api_registered = False

    @staticmethod
    def store_ori_attr(ori_api_group, api_list, api_ori_attr):
        """Record the current attribute for each name of ``api_list`` into ``api_ori_attr``."""
        for api in api_list:
            api_ori_attr[api] = _get_attr(ori_api_group, api)

    @staticmethod
    def set_api_attr(api_group, attr_dict):
        """Set each ``name -> attr`` of ``attr_dict`` on ``api_group`` (one dotted level allowed)."""
        for api, api_attr in attr_dict.items():
            if Const.SEP in api:
                sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
                sub_module = getattr(api_group, sub_module_name, None)
                if sub_module is not None:
                    setattr(sub_module, sub_op, api_attr)
            else:
                setattr(api_group, api, api_attr)

    @staticmethod
    def register_custom_api(module, api_name, api_prefix, hook_build_func, api_template):
        """Replace ``module.<api_name>`` with a wrapper built from ``api_template``."""
        def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template):
            def api_function(*args, **kwargs):
                return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs)

            api_function.__name__ = api_name
            return api_function

        setattr(module, api_name,
                wrap_api_func(api_name, getattr(module, api_name), api_prefix, hook_build_func, api_template))

    @staticmethod
    def _pair_inner_used_attrs(api_names, ori_attrs, wrapped_attrs):
        """Return ``(ori, wrapped)`` dicts for the names that have BOTH an original and a wrapper.

        ``ori_attrs`` may hold non-callable attributes that were never wrapped. Pairing
        them with a missing wrapper used to store None, which register_inner_used_api
        would then set onto the module.
        """
        ori_pairs = dict()
        wrapped_pairs = dict()
        for api_name in api_names:
            ori_func = ori_attrs.get(api_name)
            wrapped_func = wrapped_attrs.get(api_name)
            if ori_func and wrapped_func:
                ori_pairs[api_name] = ori_func
                wrapped_pairs[api_name] = wrapped_func
        return ori_pairs, wrapped_pairs

    def register_all_api(self):
        """Install the wrapped APIs on every target module (``api_modules[1]``)."""
        self.all_api_registered = True
        for framework, api_types in self.api_types.items():
            for api_type, api_modules in api_types.items():
                api_type_with_framework = framework + Const.SEP + api_type
                for module in api_modules[1]:
                    self.set_api_attr(module, self.wrapped_api_attr.get(api_type_with_framework, {}))

    def register_inner_used_api(self):
        """Install the wrapped versions of the inner-used APIs."""
        for api_type in self.inner_used_api.keys():
            self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_wrapped_attr.get(api_type, {}))

    def restore_all_api(self):
        """Put the original APIs back on every target module (``api_modules[1]``)."""
        self.all_api_registered = False
        for framework, api_types in self.api_types.items():
            for api_type, api_modules in api_types.items():
                api_type_with_framework = framework + Const.SEP + api_type
                for module in api_modules[1]:
                    self.set_api_attr(module, self.ori_api_attr.get(api_type_with_framework, {}))

    def restore_inner_used_api(self):
        """Put the original versions of the inner-used APIs back."""
        for api_type in self.inner_used_api.keys():
            self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_ori_attr.get(api_type, {}))

    def initialize_hook(self, hook_build_func):
        """Build wrappers and snapshot originals for all APIs, including the inner-used ones."""
        api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path, self.backlist)
        wrapped_api_functions = api_wrapper.wrap_api(self.api_templates, hook_build_func)

        for framework, api_types in self.api_types.items():
            for api_type, api_modules in api_types.items():
                ori_attr = dict()
                self.store_ori_attr(api_modules[0], api_wrapper.api_names.get(framework).get(api_type), ori_attr)
                api_type_with_framework = framework + Const.SEP + api_type
                self.ori_api_attr[api_type_with_framework] = ori_attr
                self.wrapped_api_attr[api_type_with_framework] = wrapped_api_functions.get(framework).get(api_type)

        # inner_used_api values: index 0 is the target module, the rest are API names.
        for inner_used_api_type, inner_used_api_list in self.inner_used_api.items():
            ori_attr, wrapped_attr = self._pair_inner_used_attrs(
                inner_used_api_list[1:],
                self.ori_api_attr.get(inner_used_api_type, {}),
                self.wrapped_api_attr.get(inner_used_api_type) or {},
            )
            self.inner_used_ori_attr[inner_used_api_type] = ori_attr
            self.inner_used_wrapped_attr[inner_used_api_type] = wrapped_attr
@@ -41,7 +41,7 @@ class DataCollector:
41
41
  self.backward_module_names = {}
42
42
  self.optimizer_status = ""
43
43
  self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
44
- atexit.register(self.write_json)
44
+ atexit.register(self.write_json_at_exit)
45
45
 
46
46
  @property
47
47
  def dump_data_dir(self):
@@ -78,6 +78,11 @@ class DataCollector:
78
78
def write_json(self):
    """Delegate JSON output to the collector's data writer."""
    writer = self.data_writer
    writer.write_json()
80
80
 
81
def write_json_at_exit(self):
    """Flush collected data at interpreter exit (registered with ``atexit``).

    For asynchronous tensor dumps (``config.async_dump`` set and
    ``config.task == Const.TENSOR``), pending async data is dumped first.

    The JSON output is written in a ``finally`` clause. Without it, an exception
    raised while dumping async data would also discard every JSON record
    collected so far. The original exception still propagates after the
    JSON has been written.
    """
    try:
        if self.config.async_dump and self.config.task == Const.TENSOR:
            self.data_processor.dump_async_data()
    finally:
        self.data_writer.write_json()
85
+
81
86
  def update_data(self, name, data_info):
82
87
  msg = f"msprobe is collecting data on {name}."
83
88
  if self.config.task == Const.OVERFLOW_CHECK:
@@ -89,6 +94,10 @@ class DataCollector:
89
94
  logger.debug(msg)
90
95
  self.data_writer.update_data(data_info)
91
96
 
97
def call_stack_collect(self, name):
    """Record the call stack for ``name`` through the data writer.

    The stack information comes from
    ``data_processor.analyze_api_call_stack(name)``. It is then passed, together
    with ``name``, to ``data_writer.update_stack``.
    """
    frames = self.data_processor.analyze_api_call_stack(name)
    self.data_writer.update_stack(name, frames)
100
+
92
101
  def forward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
93
102
  if self.config.task == Const.FREE_BENCHMARK:
94
103
  backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
@@ -118,9 +127,16 @@ class DataCollector:
118
127
  self.set_is_recomputable(data_info, is_recompute)
119
128
  if self.config.level == Const.LEVEL_L2:
120
129
  return
121
- self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
130
+ self.call_stack_collect(name)
122
131
  self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
123
132
 
133
def forward_data_collect_only_tensor(self, name, module, pid, module_input_output):
    """Run forward analysis for ``name`` when it passes the scope/pid check.

    Only ``data_processor.analyze_forward`` is invoked, and its result is
    discarded. Construct info, call stacks and ``handle_data`` are not
    touched here. Always returns ``None``.
    """
    if self.check_scope_and_pid(self.scope, name, pid):
        self.data_processor.analyze_forward(name, module, module_input_output)
138
+
139
+
124
140
  def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
125
141
  self.update_construct(name)
126
142
  if not self.check_scope_and_pid(self.scope, name, pid):
@@ -130,9 +146,15 @@ class DataCollector:
130
146
  if self.config.task != Const.STRUCTURE:
131
147
  data_info = self.data_processor.analyze_forward(name, module, module_input_output)
132
148
  self.set_is_recomputable(data_info, is_recompute)
133
- self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
149
+ self.call_stack_collect(name)
134
150
  self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
135
151
 
152
def backward_data_collect_only_tensor(self, name, module, pid, module_input_output, is_recompute=None):
    """Run backward analysis for ``name`` when it passes the scope/pid check.

    Only ``data_processor.analyze_backward`` is invoked, and its result is
    discarded. ``is_recompute`` is accepted but not used by this method.
    Always returns ``None``.
    """
    if self.check_scope_and_pid(self.scope, name, pid):
        self.data_processor.analyze_backward(name, module, module_input_output)
157
+
136
158
  def backward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
137
159
  self.update_construct(name)
138
160
  if not self.check_scope_and_pid(self.scope, name, pid):
@@ -180,7 +202,10 @@ class DataCollector:
180
202
  self.optimizer_status_first_start[self.optimizer_status] = False
181
203
  self.data_writer.update_construct({name: self.optimizer_status})
182
204
  else:
183
- self.data_writer.update_construct({name: self.module_processor.api_parent_node})
205
+ if self.config.level == Const.LEVEL_MIX and \
206
+ not (name.startswith(Const.MODULE) or name.startswith(Const.CELL)):
207
+ self.data_writer.update_construct({name: self.module_processor.api_parent_node})
208
+
184
209
  self.data_writer.update_construct(self.module_processor.module_node)
185
210
 
186
211
  def handle_data(self, name, data_info, flush=False):
@@ -204,6 +229,7 @@ class DataCollector:
204
229
 
205
230
  def params_data_collect(self, name, param_name, pid, data):
206
231
  grad_name = name + Const.SEP + Const.PARAMS_GRAD
232
+ self.update_api_or_module_name(grad_name)
207
233
  # 校验scope和pid,以及当前name是否有过反向计算
208
234
  if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
209
235
  # 如果没有反向计算,则需要清除之前占位写入的grad数据
@@ -213,18 +239,19 @@ class DataCollector:
213
239
  data_info = self.data_processor.analyze_params(grad_name, param_name, data)
214
240
  self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
215
241
 
216
- def fill_stack_tensor_data(self):
217
- self.data_writer.fill_stack_tensor_data()
218
242
 
219
243
def debug_data_collect_forward(self, variable, name_with_count):
    """Analyse a debug variable and record it in the debug cache.

    The record key is ``name_with_count`` followed by ``Const.SEP`` and
    ``Const.DEBUG``. Its value is whatever
    ``data_processor.analyze_debug_forward`` returns.
    """
    debug_info = self.data_processor.analyze_debug_forward(variable, name_with_count)
    self.data_writer.update_debug(
        {name_with_count + Const.SEP + Const.DEBUG: debug_info}
    )
223
248
 
224
249
  def debug_data_collect_backward(self, variable, grad_name_with_count):
225
250
  # prepare all None nested data structure
226
251
  all_none_data_info = self.data_processor.analyze_element_to_all_none(variable)
227
- self.data_writer.update_debug({grad_name_with_count: all_none_data_info})
252
+ grad_name_with_count_category = grad_name_with_count + Const.SEP + Const.DEBUG
253
+ self.data_writer.update_debug({grad_name_with_count_category: all_none_data_info})
228
254
 
229
255
  # register tensor backward hook
230
- self.data_processor.analyze_debug_backward(variable, grad_name_with_count, self.data_writer.cache_debug['data'])
256
+ self.data_processor.analyze_debug_backward(variable, grad_name_with_count_category,
257
+ self.data_writer.cache_debug['data'])