mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,451 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ # 标准库
19
+ import argparse
20
+ import json
21
+ import os
22
+ import re
23
+ import string
24
+
25
+ # 应用程序自定义模块
26
+ from msprobe.core.common.file_utils import (
27
+ FileOpen,
28
+ load_json,
29
+ save_json,
30
+ make_dir,
31
+ change_mode,
32
+ )
33
+ from msprobe.core.common.utils import (
34
+ check_file_or_directory_path,
35
+ check_op_str_pattern_valid,
36
+ is_int,
37
+ )
38
+ from msprobe.core.common.const import Const, MonitorConst, MsgConst, FileCheckConst
39
+ from msprobe.core.common.log import logger
40
+ from msprobe.core.common.decorator import recursion_depth_decorator
41
+
42
# API type prefixes (leading segment of a dumped API name) accepted by APIInfo.is_supported_type().
OPERATOR_TYPE = ("Functional", "Tensor", "Torch", "Mint")

# Maximum number of API entries in an extracted json: one forward plus one backward entry.
API_INFO = 2
# Segment counts of a dumped API name handled by OperatorScriptGenerator.extract_detailed_api_segments().
FOUR_SEGMENT = 4
FIVE_SEGMENT = 5
# Key of a tensor entry whose value APIExtractor prefixes with the dump data directory.
DATA_NAME = "data_name"
# Upper bound on the length of the configured api_name.
API_MAX_LENGTH = 30
# Allowed values of the "propagation" config field.
PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
# Allowed values of the "data_mode" config field.
DATAMODE_LIST = ["random_data", "real_data"]
# Inclusive upper bound of the "iter_times" config field.
ITER_MAX_TIMES = 1000
# Metadata keys that APIExtractor stores next to the extracted API entries.
FRAMEWORK = 'framework'
REAL_DATA_PATH = 'real_data_path'
# Metadata keys that are not API entries; skipped when counting APIs.
# NOTE: the name is misspelled (EXCLUDED) but kept, since other code may import it.
EXCLUED = {FRAMEWORK, REAL_DATA_PATH}
55
+
56
+
57
class APIInfo:
    """One dumped API entry: its full name, its info dict and an optional backward entry."""

    def __init__(self, api_full_name, api_info_dict, backward_info=None):
        self.api_full_name = api_full_name
        self.api_info_dict = api_info_dict
        # APIInfo of the matching backward entry; only set in backward mode.
        self.backward_info = backward_info

    @property
    def api_type(self):
        """Leading segment of ``api_full_name`` when split on ``Const.SEP`` (e.g. ``Mint``)."""
        return self.api_full_name.split(Const.SEP)[0]

    @classmethod
    def from_json(cls, json_content, propagation):
        """Build an APIInfo from the content of an extracted API json.

        Metadata keys (``framework`` / ``real_data_path``) are skipped, so only real API entries
        are considered. The first API entry is the forward entry; in backward mode, the second
        one is attached as ``backward_info``.

        Args:
            json_content: dict mapping entry names to their info dicts.
            propagation: ``Const.FORWARD`` or ``Const.BACKWARD``.

        Returns:
            APIInfo: the forward entry (with ``backward_info`` set in backward mode).

        Raises:
            ValueError: if required entries are missing or the API type is not supported.
        """
        api_items = [(name, info) for name, info in json_content.items() if name not in EXCLUED]
        if not api_items:
            raise ValueError("json file does not contain any API info!")

        forward_name, forward_dict = api_items[0]
        forward_info = cls(api_full_name=forward_name, api_info_dict=forward_dict)

        if propagation == Const.BACKWARD:
            if len(api_items) < API_INFO:
                raise ValueError('Backward propagation requires contains forward and backward info!')
            backward_name, backward_dict = api_items[1]
            forward_info.backward_info = cls(api_full_name=backward_name, api_info_dict=backward_dict)

        if not forward_info.is_supported_type():
            raise ValueError(f"type {forward_info.api_type} of API is not supported!")

        return forward_info

    def is_supported_type(self):
        """Return True when ``api_type`` is one of ``OPERATOR_TYPE``."""
        return self.api_type in OPERATOR_TYPE
+
85
+
86
class CommonConfig:
    """Validated view of the op-generator config json.

    Reads ``dump_json_path``, ``api_name``, ``extract_api_path``, ``propagation``, ``data_mode``,
    ``random_seed`` and ``iter_times`` from *json_config* and validates them on construction.
    """

    def __init__(self, json_config):
        self.dump_json_path = json_config.get('dump_json_path')
        self.api_name = json_config.get('api_name')
        self.extract_api_path = json_config.get('extract_api_path')
        self.propagation = json_config.get('propagation')
        self.data_mode = json_config.get('data_mode')
        self.random_seed = json_config.get('random_seed')
        self.iter_times = json_config.get('iter_times')
        self._check_config()

    def check_user_settings(self):
        """Validate ``iter_times`` and the extracted API json, then return the json content.

        Metadata keys (``framework`` / ``real_data_path``) are ignored when counting and picking
        API entries. The first API entry is the forward entry and the second (backward mode
        only) is the backward entry.

        Returns:
            dict: the full content of ``extract_api_path``, metadata keys included.

        Raises:
            ValueError: if ``iter_times`` is out of range or the json content is invalid.
        """
        if self.iter_times <= 0 or self.iter_times > ITER_MAX_TIMES:
            raise ValueError(f"iter_times should be range from 1 to {ITER_MAX_TIMES}.")

        propagation = self.propagation
        json_content = load_json(self.extract_api_path)

        # ensure the content is not empty
        if not json_content:
            raise ValueError('json file is empty!')

        # ensure the content is a dict (checked before calling .items())
        if not isinstance(json_content, dict):
            raise ValueError('content of json file is not a dict!')

        # only real API entries count; metadata keys are skipped
        api_items = [(key, value) for key, value in json_content.items() if key not in EXCLUED]

        if len(api_items) > API_INFO:
            raise ValueError('json file has more than one API, the API only contains forward and backward info')

        # In forward mode, two forward entries mean two different APIs were extracted.
        # A single forward entry is valid and must not be rejected.
        if (propagation == Const.FORWARD and len(api_items) > 1
                and all(key.endswith('forward') for key, _ in api_items)):
            raise ValueError(
                "json file has more than one API, the API only contains forward info。"
            )

        # the first API entry must be a non-empty dict
        forward_item = api_items[0] if api_items else None
        if not forward_item or not isinstance(forward_item[1], dict) or not forward_item[1]:
            raise ValueError('Invalid forward API data in json_content!')

        if propagation == Const.BACKWARD:
            # backward mode needs both a forward and a backward entry
            if len(api_items) < API_INFO:
                raise ValueError('Backward propagation requires contains forward and backward info!')
            backward_item = api_items[1]
            if not isinstance(backward_item[1], dict) or not backward_item[1]:
                raise ValueError('Invalid backward API data in json_content!')

        return json_content

    def _check_config(self):
        """Validate the raw config values and create the directory of ``extract_api_path``.

        Raises:
            ValueError: on an invalid or missing config value.
        """
        if self.dump_json_path:
            check_file_or_directory_path(self.dump_json_path)
        if self.api_name:
            check_op_str_pattern_valid(self.api_name)
            if len(self.api_name) > API_MAX_LENGTH:
                raise ValueError(f'API name {self.api_name} is too long!')
        # extract_api_path is always read later, so it is mandatory
        if not self.extract_api_path:
            raise ValueError('extract_api_path can not be empty, please check.')
        extract_dir = os.path.dirname(self.extract_api_path)
        # a bare file name has no directory component to create
        if extract_dir:
            make_dir(extract_dir)
        if self.propagation and self.propagation not in PROPAGATION_LIST:
            raise ValueError(f'propagation is invalid, it should be one of {PROPAGATION_LIST}')
        if self.data_mode and self.data_mode not in DATAMODE_LIST:
            raise ValueError(f'data_mode is invalid, it should be one of {DATAMODE_LIST}')
        if not is_int(self.random_seed):
            raise ValueError('random_seed is invalid, it should be an int')
        if not is_int(self.iter_times):
            raise ValueError('iter_times is invalid, it should be an int')
+
161
+
162
class APIExtractor:
    """Copy the dump entries of one API from a dump json into a standalone json file."""

    def __init__(self, api_name, dump_json_path, output_file):
        self.api_name = api_name
        self.dump_json_path = dump_json_path
        self.output_file = output_file
        self.data = None
        self.framework = None
        self.real_data_path = None

    def extract_op(self):
        """Extract every ``data`` entry whose key starts with ``<api_name>.`` into ``output_file``.

        When the dump json has a ``dump_data_dir``, the ``data_name`` of each extracted tensor
        is prefixed with it. The ``framework`` and ``real_data_path`` metadata are stored next
        to the extracted entries, but only when at least one entry matched. Otherwise, a
        warning is logged and nothing is written.

        Returns:
            bool: True if entries were extracted and saved, False if the API was not found.

        Raises:
            ValueError: if ``api_name`` is empty.
        """
        if not self.api_name:
            raise ValueError("api_name can not be empty, please check.")

        self.data = load_json(self.dump_json_path)
        self.framework = self.data.get(FRAMEWORK, None)
        self.real_data_path = self.data.get('dump_data_dir', '')

        # Keys must be "<api_name>." followed by at least one character; raw string avoids an
        # invalid "\." escape in the f-string.
        extract_key_pattern = re.compile(rf"^{re.escape(self.api_name)}\..+")

        new_data = {}
        for key, value in self.data.get('data', {}).items():
            if not extract_key_pattern.match(key):
                continue
            if self.real_data_path:
                value = self.load_real_data_path(value, self.real_data_path)
            new_data[key] = value

        # Decide "not found" before adding metadata keys, which would always make new_data non-empty.
        if not new_data:
            logger.warning(f"Warning: The api '{self.api_name}' does not exist in the file.")
            return False

        if self.real_data_path is not None:
            new_data[REAL_DATA_PATH] = self.real_data_path
        if self.framework is not None:
            new_data[FRAMEWORK] = self.framework

        save_json(self.output_file, new_data, indent=4)
        logger.info(
            f"The api '{self.api_name}' has been successfully extracted and saved in: {self.output_file}")
        return True

    def load_real_data_path(self, value, dump_data_dir):
        """Prefix the ``data_name`` of every tensor in the input/output fields of *value*.

        The input-args, grad-input, input, output and grad-output fields are updated in place.
        Missing or ``None`` fields are skipped.

        Returns:
            the same *value* object.
        """
        parameters = [Const.INPUT_ARGS, Const.GRAD_INPUT, Const.INPUT, Const.OUTPUT, Const.GRAD_OUTPUT]
        for parameter in parameters:
            for v in value.get(parameter) or []:
                if v is not None:
                    self.update_data_name(v, dump_data_dir)
        return value

    @recursion_depth_decorator("OpGenerator: APIExtractor.update_data_name")
    def update_data_name(self, data, dump_data_dir):
        """Recursively join *dump_data_dir* onto the ``data_name`` of dicts nested in *data*.

        Items that are neither lists nor dicts (e.g. ``None``) are ignored.
        """
        if isinstance(data, list):
            for item in data:
                self.update_data_name(item, dump_data_dir)
        elif isinstance(data, dict) and DATA_NAME in data:
            data[DATA_NAME] = os.path.join(dump_data_dir, data[DATA_NAME])
+
216
+
217
class OperatorScriptGenerator:
    """Assemble the settings (and code snippets) used to render the operator replication template.

    Constructor args:
        common_config: validated CommonConfig of the current run.
        args_info_forward: positional-input info of the forward entry.
        kwargs_info_forward: keyword-input info of the forward entry.
        args_info_backward: gradient-input info of the backward entry, or None in forward mode.
    """

    def __init__(self, common_config, args_info_forward, kwargs_info_forward, args_info_backward):
        self.common_config = common_config
        self.args_info_forward = args_info_forward
        self.kwargs_info_forward = kwargs_info_forward
        self.args_info_backward = args_info_backward

    @staticmethod
    def extract_detailed_api_segments(full_api_name):
        """Split a dumped API name into its type, name and call ordinal.

        Args:
            full_api_name: a name with 4 segments, e.g. ``Functional.relu.0.forward``, or with
                5 segments, where the 2nd and 3rd segments are re-joined into the API name.

        Returns:
            tuple: ``(api_type, api_name, api_order)``. All three are None when the name has
            neither 4 nor 5 segments.
        """
        api_parts = full_api_name.split(Const.SEP)
        api_parts_length = len(api_parts)
        api_type, api_name, api_order = None, None, None
        if api_parts_length == FOUR_SEGMENT:
            api_type, api_name, api_order, _ = api_parts
        elif api_parts_length == FIVE_SEGMENT:
            api_type, prefix, api_name, api_order, _ = api_parts
            api_name = Const.SEP.join([prefix, api_name])
        return api_type, api_name, api_order

    @staticmethod
    def _collect_parameter_names(info):
        """Return the ``parameter_name`` of every dict nested in *info*, depth-first, skipping None."""
        names = []

        def collect(node):
            if node is None:
                return
            if isinstance(node, dict):
                names.append(node["parameter_name"])
            else:
                for sub in node:
                    collect(sub)

        collect(info)
        return names

    @staticmethod
    def generate_forward_inputs_code(args_info):
        """Return code that builds ``forward_inputs`` from the parameter names in *args_info*.

        The names are emitted inside a list literal. A tuple-style ``(x)`` would not be a tuple
        for a single name, and the generated loop would then iterate over ``x`` itself.
        """
        names = OperatorScriptGenerator._collect_parameter_names(args_info)
        return (
            " forward_inputs = [\n"
            " ComputeElement(parameter=info)\n"
            " for info in [" + ", ".join(names) + "]\n"
            " ]\n"
        )

    @staticmethod
    def generate_kwargs_compute_element_dict_code():
        """Return code that wraps every entry of ``kwargs_device`` in a ComputeElement."""
        return (
            " # ---- 构造 kwargs 对应的 ComputeElement 字典 ----\n"
            " kwargs_compute_element_dict = {\n"
            " key_str: ComputeElement(compute_element_info=compute_element_info)\n"
            " for key_str, compute_element_info in kwargs_device.items()\n"
            " }\n"
        )

    @staticmethod
    def generate_gradient_inputs_code(args_info_backward):
        """Return code that builds ``gradient_inputs`` from the names in *args_info_backward*.

        As in generate_forward_inputs_code, a list literal keeps a single name iterable.
        """
        names = OperatorScriptGenerator._collect_parameter_names(args_info_backward)
        return (
            " # —— 构造反向梯度 ComputeElement 列表 —— #\n"
            " gradient_inputs = [\n"
            " ComputeElement(parameter=info)\n"
            " for info in [" + ", ".join(names) + "]\n"
            " ]\n"
        )

    def get_settings(self, api_full_name):
        """Build the ``internal_settings`` dict used to fill the replication template.

        Keys:
            propagation, direction_status: ``common_config.propagation``.
            api_full_name: the name passed in.
            api_type: ``torch.nn.functional`` for Functional, ``torch.Tensor`` for Tensor,
                ``torch`` for any other type.
            api_name, ordinal_number: parsed by extract_detailed_api_segments().
            random_seed, data_mode: copied from ``common_config``.
            iter_times: 1 in ``real_data`` mode, otherwise ``common_config.iter_times``.
            args_info_forward, kwargs_info_forward, args_info_backward: the raw input infos.
        """
        api_type, api_name, ordinal_number = self.extract_detailed_api_segments(api_full_name)
        type_mapping = {"Functional": "torch.nn.functional", "Tensor": "torch.Tensor"}
        config = self.common_config
        # real data is replayed exactly once; random data is regenerated iter_times times
        iter_times = 1 if config.data_mode == "real_data" else config.iter_times
        return {
            "propagation": config.propagation,
            "api_full_name": api_full_name,
            "api_type": type_mapping.get(api_type, "torch"),
            "api_name": api_name,
            "ordinal_number": ordinal_number,
            "direction_status": config.propagation,
            "random_seed": config.random_seed,
            "data_mode": config.data_mode,
            "iter_times": iter_times,
            "args_info_forward": self.args_info_forward,
            "kwargs_info_forward": self.kwargs_info_forward,
            "args_info_backward": self.args_info_backward,
        }
+
344
+
345
+ def _op_generator_parser(parser):
346
+ parser.add_argument("-i", "--config_input", dest="config_input", type=str,
347
+ help="<Required> Path of config json file", required=True)
348
+ parser.add_argument("-o", "--api_output_path", dest="api_output_path", type=str,
349
+ help="<Required> Path of extract api_name.json.", required=True)
350
+
351
+
352
def parse_json_config(json_file_path):
    """Load the op-generator config json and wrap it in a validated CommonConfig.

    Raises:
        ValueError: if *json_file_path* is empty. ValueError subclasses Exception, so existing
            ``except Exception`` handlers keep working.
    """
    if not json_file_path:
        raise ValueError("config_input path can not be empty, please check.")
    return CommonConfig(load_json(json_file_path))
358
+
359
+
360
class _RobustFormatter(string.Formatter):
    """A ``str.format``-like formatter that leaves unresolvable placeholders untouched.

    The rendered output is a ``.py`` script, so the template may contain ``{...}`` spans that
    are not template fields. Any field that cannot be looked up, converted or formatted is
    written back verbatim as ``{field_name[!conversion][:format_spec]}``.
    """

    def vformat(self, format_string, args, kwargs):
        result = []
        # parse() yields the literal text before each placeholder together with the placeholder
        for literal, field_name, format_spec, conversion in self.parse(format_string):
            result.append(literal)
            if field_name is None:
                continue
            try:
                obj, _ = self.get_field(field_name, args, kwargs)
                if conversion:
                    obj = self.convert_field(obj, conversion)
                result.append(self.format_field(obj, format_spec))
            except (KeyError, IndexError, AttributeError, TypeError, ValueError):
                # lookup / conversion / formatting failed: restore the original placeholder text
                placeholder = '{' + field_name
                if conversion:
                    placeholder += '!' + conversion
                if format_spec:
                    placeholder += ':' + format_spec
                placeholder += '}'
                result.append(placeholder)
        return ''.join(result)


def _render_operator_script(internal_settings, output_dir):
    """Fill operator_replication.template with *internal_settings* and write it under *output_dir*.

    The template is rendered before the output file is opened, so a rendering error does not
    leave a truncated script behind.

    Returns:
        str: path of the generated ``<api_full_name>.py`` script.
    """
    template_path = os.path.join(os.path.dirname(__file__), "operator_replication.template")
    operator_script_path = os.path.join(output_dir,
                                        "{0}.py".format(internal_settings.get("api_full_name")))

    with FileOpen(template_path, 'r') as ftemp:
        code_template = ftemp.read()
    rendered = _RobustFormatter().format(code_template, **internal_settings)

    with FileOpen(operator_script_path, 'w') as fout:
        fout.write(rendered)
    change_mode(operator_script_path, FileCheckConst.DATA_FILE_AUTHORITY)
    return operator_script_path


def _run_operator_generate_commond(cmd_args):
    """Generate an operator replication script for one API.

    Steps:
    1. Load and validate the config (``cmd_args.config_input``).
    2. If ``dump_json_path`` is configured, extract the API entries into ``extract_api_path``.
    3. Validate the extracted json.
    4. Collect the forward input info and, in backward mode, the gradient input info.
    5. Render the template into ``cmd_args.api_output_path``.

    Raises:
        ValueError: on invalid config or extracted json, or when extraction reports that
            nothing was extracted.
    """
    common_config = parse_json_config(cmd_args.config_input)

    framework, real_data_path = None, None
    if common_config.dump_json_path:
        api_extract = APIExtractor(common_config.api_name, common_config.dump_json_path,
                                   common_config.extract_api_path)
        # Abort instead of silently reusing a stale extract file from an earlier run.
        if api_extract.extract_op() is False:
            raise ValueError(f"The api '{common_config.api_name}' does not exist in "
                             f"{common_config.dump_json_path}.")
        framework = api_extract.framework
        real_data_path = api_extract.real_data_path

    check_file_or_directory_path(common_config.extract_api_path)
    check_file_or_directory_path(cmd_args.api_output_path, isdir=True)
    json_content = common_config.check_user_settings()

    if not common_config.dump_json_path:
        # No extraction this run: use the metadata APIExtractor stored in the extracted json.
        framework = json_content.get(FRAMEWORK)
        real_data_path = json_content.get(REAL_DATA_PATH)

    api_info = APIInfo.from_json(json_content, common_config.propagation)
    args_info_forward = api_info.api_info_dict.get(Const.INPUT_ARGS)
    kwargs_info_forward = api_info.api_info_dict.get(Const.INPUT_KWARGS)
    args_info_backward = None
    settings_api_name = api_info.api_full_name

    if common_config.propagation == Const.BACKWARD:
        backward_dict = api_info.backward_info.api_info_dict
        if Const.GRAD_INPUT in backward_dict:
            args_info_backward = backward_dict.get(Const.GRAD_INPUT)
        elif Const.INPUT in backward_dict:
            args_info_backward = backward_dict.get(Const.INPUT)
        # backward scripts are named after the backward entry
        settings_api_name = api_info.backward_info.api_full_name

    op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward,
                                          args_info_backward)
    internal_settings = op_generate.get_settings(settings_api_name)
    internal_settings[FRAMEWORK] = framework
    internal_settings[REAL_DATA_PATH] = real_data_path

    operator_script_path = _render_operator_script(internal_settings, cmd_args.api_output_path)
    logger.info(f"Generate operator script successfully and the name is {operator_script_path}.")
445
+
446
+
447
if __name__ == "__main__":
    # CLI entry: parse -i/-o and run the generator
    arg_parser = argparse.ArgumentParser()
    _op_generator_parser(arg_parser)
    _run_operator_generate_commond(arg_parser.parse_args())