mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213) hide show
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,460 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ # 标准库
19
+ import argparse
20
+ import json
21
+ import os
22
+ import re
23
+ import string
24
+
25
+ # 应用程序自定义模块
26
+ from msprobe.core.common.file_utils import (
27
+ FileOpen,
28
+ load_json,
29
+ save_json,
30
+ make_dir,
31
+ change_mode,
32
+ )
33
+ from msprobe.core.common.utils import (
34
+ check_file_or_directory_path,
35
+ check_op_str_pattern_valid,
36
+ is_int,
37
+ )
38
+ from msprobe.core.common.const import Const, MonitorConst, MsgConst, FileCheckConst
39
+ from msprobe.core.common.log import logger
40
+ from msprobe.core.common.decorator import recursion_depth_decorator
41
+
42
# API categories (first segment of an API's full name) that this generator accepts.
OPERATOR_TYPE = ("Functional", "Tensor", "Torch", "Mint")

# Maximum number of API entries allowed in an extracted json file.
API_INFO = 2
# Numbers of separator-delimited segments recognised in an API's full name.
FOUR_SEGMENT = 4
FIVE_SEGMENT = 5
# Key whose value is joined onto the dump data directory during extraction.
DATA_NAME = "data_name"
# Longest api_name accepted from the user config.
API_MAX_LENGTH = 30
# Allowed values of the "propagation" and "data_mode" config entries.
PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
DATAMODE_LIST = [
    "random_data",
    "real_data",
]
# Upper bound (inclusive) of the "iter_times" config entry.
ITER_MAX_TIMES = 1000
# Metadata keys stored next to the API entries in the extracted json file.
FRAMEWORK = "framework"
REAL_DATA_PATH = "real_data_path"
# Metadata keys that are ignored when counting API entries.
EXCLUED = {FRAMEWORK, REAL_DATA_PATH}
55
+
56
+
57
class APIInfo:
    """One API entry of an extracted json file, optionally linked to its backward entry.

    Attributes:
        api_full_name: key of the entry in the json file.
        api_info_dict: value stored under that key.
        backward_info: ``APIInfo`` of the backward entry, or None.
    """

    def __init__(self, api_full_name, api_info_dict, backward_info=None):
        self.api_full_name = api_full_name
        self.api_info_dict = api_info_dict
        self.backward_info = backward_info

    @property
    def api_type(self):
        """Return the first ``Const.SEP``-separated segment of the full API name."""
        return self.api_full_name.split(Const.SEP, -1)[0]

    @classmethod
    def from_json(cls, json_content, propagation):
        """Build an ``APIInfo`` from the content of an extracted json file.

        The first API entry is treated as the forward entry; when ``propagation``
        is ``Const.BACKWARD`` the second one is attached as ``backward_info``.

        Raises:
            ValueError: if no API entry exists, if the backward entry is required
                but missing, or if the API type is not in ``OPERATOR_TYPE``.
        """
        # Metadata keys (framework / real_data_path) live in the same dict as the
        # API entries; drop them so positional lookup only ever sees API entries.
        api_entries = [(name, info) for name, info in json_content.items() if name not in EXCLUED]
        if not api_entries:
            raise ValueError("json content does not contain any API info!")

        forward_name, forward_dict = api_entries[0]
        forward_info = cls(api_full_name=forward_name, api_info_dict=forward_dict)

        if propagation == Const.BACKWARD:
            if len(api_entries) < 2:
                raise ValueError("Backward propagation requires forward and backward info!")
            backward_name, backward_dict = api_entries[1]
            forward_info.backward_info = cls(api_full_name=backward_name, api_info_dict=backward_dict)

        if not forward_info.is_supported_type():
            raise ValueError(f"type {forward_info.api_type} of API is not supported!")

        return forward_info

    def is_supported_type(self):
        """Return True when ``api_type`` is one of ``OPERATOR_TYPE``."""
        return self.api_type in OPERATOR_TYPE
84
+
85
+
86
class CommonConfig:
    """User configuration of the operator script generator, read from a json dict."""

    def __init__(self, json_config):
        """Copy the known keys out of ``json_config`` and validate them.

        Raises:
            ValueError: if a configured value is rejected by ``_check_config``.
        """
        self.dump_json_path = json_config.get('dump_json_path')
        self.api_name = json_config.get('api_name')
        self.extract_api_path = json_config.get('extract_api_path')
        self.propagation = json_config.get('propagation')
        self.data_mode = json_config.get('data_mode')
        self.random_seed = json_config.get('random_seed')
        self.iter_times = json_config.get('iter_times')
        self._check_config()

    def check_user_settings(self):
        """Load the extracted API json file and validate it against the config.

        Returns:
            The loaded json content (metadata keys included).

        Raises:
            ValueError: if ``iter_times`` is out of range or the json content does
                not describe exactly the API entries required by ``propagation``.
        """
        iter_t = self.iter_times
        if iter_t <= 0 or iter_t > ITER_MAX_TIMES:
            raise ValueError(f"iter_times should be range from 1 to {ITER_MAX_TIMES}.")

        propagation = self.propagation
        json_content = load_json(self.extract_api_path)

        # ensure the dict is not empty
        if not json_content:
            raise ValueError('json file is empty!')

        # ensure json_content is of type dict
        if not isinstance(json_content, dict):
            raise ValueError('content of json file is not a dict!')

        # Only real API entries count; metadata keys are stored in the same dict.
        filtered = {k: v for k, v in json_content.items() if k not in EXCLUED}

        if not filtered:
            raise ValueError('json file is empty!')

        if len(filtered) > API_INFO:
            raise ValueError('json file has more than one API, the API only contains forward and backward info')

        is_forward_phase = propagation == Const.FORWARD
        is_exact_api_count = len(filtered) == API_INFO
        all_keys_forward = all(k.endswith('forward') for k in filtered)
        if is_forward_phase and is_exact_api_count and all_keys_forward:
            raise ValueError(
                "json file has more than one API, the API only contains forward info。"
            )

        # Look the entries up in the filtered view so a metadata key can never be
        # mistaken for the forward or backward API entry.
        api_entries = list(filtered.items())

        forward_item = api_entries[0]
        if not isinstance(forward_item[1], dict) or not forward_item[1]:
            raise ValueError('Invalid forward API data in json_content!')

        if propagation == Const.BACKWARD:
            # backward propagation needs both the forward and the backward entry
            if len(filtered) < API_INFO:
                raise ValueError('Backward propagation requires contains forward and backward info!')
            backward_item = api_entries[1]
            if not isinstance(backward_item[1], dict) or not backward_item[1]:
                raise ValueError('Invalid backward API data in json_content!')

        return json_content

    def _check_config(self):
        """Validate the configured values and create the directory of ``extract_api_path``.

        Raises:
            ValueError: if ``api_name`` is too long, ``propagation`` / ``data_mode``
                are not among the allowed values, or ``random_seed`` /
                ``iter_times`` are not ints.
        """
        if self.dump_json_path:
            check_file_or_directory_path(self.dump_json_path)
        if self.api_name:
            check_op_str_pattern_valid(self.api_name)
            if len(self.api_name) > API_MAX_LENGTH:
                raise ValueError(f'API name {self.api_name} is too long!')
        make_dir(os.path.dirname(self.extract_api_path))
        if self.propagation and self.propagation not in PROPAGATION_LIST:
            raise ValueError(f'propagation is invalid, it should be one of {PROPAGATION_LIST}')
        if self.data_mode and self.data_mode not in DATAMODE_LIST:
            raise ValueError(f'data_mode is invalid, it should be one of {DATAMODE_LIST}')
        if not is_int(self.random_seed):
            raise ValueError('random_seed is invalid, it should be an int')
        if not is_int(self.iter_times):
            raise ValueError('iter_times is invalid, it should be an int')
169
+
170
+
171
class APIExtractor:
    """Extract the entries of one API from a dump json file into a smaller json file."""

    def __init__(self, api_name, dump_json_path, output_file):
        self.api_name = api_name
        self.dump_json_path = dump_json_path
        self.output_file = output_file
        self.data = None
        self.framework = None
        self.real_data_path = None

    def extract_op(self):
        """Copy every entry whose key starts with ``<api_name>.`` into ``output_file``.

        ``framework`` and ``real_data_path`` are read from the dump file and stored
        both on the instance and, as metadata keys, in the output file. When no
        entry matches, a warning is logged and nothing is written.
        """
        self.data = load_json(self.dump_json_path)
        self.framework = self.data.get(FRAMEWORK, None)
        self.real_data_path = self.data.get('dump_data_dir', '')

        # Raw f-string: "\." must reach the regex engine as an escaped dot.
        extract_key_pattern = re.compile(rf"^{re.escape(self.api_name)}\..+")

        new_data = {}
        for key, value in self.data.get('data', {}).items():
            if extract_key_pattern.match(key):
                if self.real_data_path:
                    value = self.load_real_data_path(value, self.real_data_path)
                new_data[key] = value

        # Decide on "not found" before the metadata keys are added; otherwise the
        # dict is never empty and the warning below could not be reached.
        if not new_data:
            logger.warning(f"Warning: The api '{self.api_name}' does not exist in the file.")
            return

        if self.real_data_path is not None:
            new_data[REAL_DATA_PATH] = self.real_data_path
        if self.framework is not None:
            new_data[FRAMEWORK] = self.framework

        save_json(self.output_file, new_data, indent=4)
        logger.info(
            f"The api '{self.api_name}' has been successfully extracted and saved in: {self.output_file}")

    def load_real_data_path(self, value, dump_data_dir):
        """Prefix the ``data_name`` fields of one API entry with ``dump_data_dir``.

        Only the argument / input / output / gradient sections are visited; the
        entry is modified in place and returned.
        """
        parameters = [Const.INPUT_ARGS, Const.GRAD_INPUT, Const.INPUT, Const.OUTPUT, Const.GRAD_OUTPUT]
        for parameter in parameters:
            # "or []" also covers a section that is present but null.
            for v in value.get(parameter) or []:
                if v is not None:
                    self.update_data_name(v, dump_data_dir)
        return value

    @recursion_depth_decorator("OpGenerator: APIExtractor.update_data_name")
    def update_data_name(self, data, dump_data_dir):
        """Recursively join ``dump_data_dir`` onto every ``data_name`` found in ``data``.

        Lists are walked element by element; leaves that are not dicts (for
        example ``None`` or plain numbers) are ignored.
        """
        if isinstance(data, list):
            for item in data:
                self.update_data_name(item, dump_data_dir)
        elif isinstance(data, dict) and DATA_NAME in data:
            data[DATA_NAME] = os.path.join(dump_data_dir, data[DATA_NAME])
224
+
225
+
226
class OperatorScriptGenerator:
    """Collect the values substituted into the operator replication template."""

    def __init__(self, common_config, args_info_forward, kwargs_info_forward, args_info_backward):
        self.common_config = common_config
        self.args_info_forward = args_info_forward
        self.kwargs_info_forward = kwargs_info_forward
        self.args_info_backward = args_info_backward

    @staticmethod
    def extract_detailed_api_segments(full_api_name):
        """
        Function Description:
            Split a full API name into its type, name and call order.
        Parameter:
            full_api_name: Name made of four or five ``Const.SEP``-separated segments,
                the last of which is discarded.
        Return:
            api_type: First segment of the name.
            api_name: Second segment; for five-segment names the second and third
                segments joined with ``Const.SEP``.
            api_order: Segment right before the discarded last one.
            All three are None when the name has neither four nor five segments.
        """
        api_parts = full_api_name.split(Const.SEP)
        api_parts_length = len(api_parts)
        api_type, api_name, api_order = None, None, None
        if api_parts_length == FOUR_SEGMENT:
            api_type, api_name, api_order, _ = api_parts
        elif api_parts_length == FIVE_SEGMENT:
            api_type, prefix, api_name, api_order, _ = api_parts
            api_name = Const.SEP.join([prefix, api_name])
        return api_type, api_name, api_order

    @staticmethod
    def _collect_parameter_names(args_info):
        """Return the ``parameter_name`` of every dict in a (nested) argument info structure.

        ``None`` (missing info, or a placeholder inside a list) contributes nothing.
        """
        names = []

        def collect(info):
            if info is None:
                return
            if isinstance(info, dict):
                names.append(info["parameter_name"])
            else:
                for sub in info:
                    collect(sub)

        collect(args_info)
        return names

    @staticmethod
    def _tuple_literal(names):
        """Render ``names`` as the source text of a tuple.

        A single name needs a trailing comma: ``(a)`` is just ``a`` in parentheses,
        so the generated loop would iterate over ``a`` itself.
        """
        joined = ", ".join(names)
        if len(names) == 1:
            joined += ","
        return "(" + joined + ")"

    # NOTE(review): the leading whitespace inside the code-fragment literals below is
    # reproduced exactly as it appears in the reviewed text; confirm it against the
    # indentation the template expects.
    @staticmethod
    def generate_forward_inputs_code(args_info):
        """Return the code fragment that builds ``forward_inputs`` from the argument names."""
        names = OperatorScriptGenerator._collect_parameter_names(args_info)
        return (
            " forward_inputs = [\n"
            " ComputeElement(parameter=info)\n"
            " for info in " + OperatorScriptGenerator._tuple_literal(names) + "\n"
            " ]\n"
        )

    @staticmethod
    def generate_kwargs_compute_element_dict_code():
        """Return the code fragment that builds ``kwargs_compute_element_dict``."""
        return (
            " # ---- 构造 kwargs 对应的 ComputeElement 字典 ----\n"
            " kwargs_compute_element_dict = {\n"
            " key_str: ComputeElement(compute_element_info=compute_element_info)\n"
            " for key_str, compute_element_info in kwargs_device.items()\n"
            " }\n"
        )

    @staticmethod
    def generate_gradient_inputs_code(args_info_backward):
        """Return the code fragment that builds ``gradient_inputs`` from the gradient names."""
        names = OperatorScriptGenerator._collect_parameter_names(args_info_backward)
        return (
            " # —— 构造反向梯度 ComputeElement 列表 —— #\n"
            " gradient_inputs = [\n"
            " ComputeElement(parameter=info)\n"
            " for info in " + OperatorScriptGenerator._tuple_literal(names) + "\n"
            " ]\n"
        )

    def get_settings(self, api_full_name):
        '''
        internal_settings contain all information needed for the operator program.
        keys:
            propagation / direction_status: propagation taken from the user config
            api_full_name: the name passed in
            api_type: "torch.nn.functional" for Functional, "torch.Tensor" for Tensor,
                      "torch" for every other type
            api_name: name of API
            ordinal_number: call order segment of the full name
            random_seed: random seed taken from the user config
            data_mode: data mode taken from the user config
            iter_times: 1 when data_mode is real_data, otherwise the configured iter_times
            args_info_forward / kwargs_info_forward / args_info_backward: argument
                      info handed to the constructor
        '''
        api_type, api_name, ordinal_number = self.extract_detailed_api_segments(api_full_name)

        internal_settings = {}
        internal_settings["propagation"] = self.common_config.propagation
        internal_settings["api_full_name"] = api_full_name
        if api_type == "Functional":
            internal_settings["api_type"] = "torch.nn.functional"
        elif api_type == "Tensor":
            internal_settings["api_type"] = "torch.Tensor"
        else:
            internal_settings["api_type"] = "torch"
        internal_settings["api_name"] = api_name
        internal_settings["ordinal_number"] = ordinal_number
        internal_settings["direction_status"] = self.common_config.propagation
        internal_settings["random_seed"] = self.common_config.random_seed
        internal_settings["data_mode"] = self.common_config.data_mode
        # Real data is replayed once; the iteration count only applies to random data.
        if self.common_config.data_mode == "real_data":
            internal_settings["iter_times"] = 1
        else:
            internal_settings["iter_times"] = self.common_config.iter_times

        internal_settings["args_info_forward"] = self.args_info_forward
        internal_settings["kwargs_info_forward"] = self.kwargs_info_forward
        internal_settings["args_info_backward"] = self.args_info_backward

        return internal_settings
352
+
353
+
354
def _op_generator_parser(parser):
    """Register the two required path options of the operator generator on ``parser``."""
    required_path_options = (
        ("-i", "--config_input", "config_input", "<Required> Path of config json file"),
        ("-o", "--api_output_path", "api_output_path", "<Required> Path of extract api_name.json."),
    )
    for short_flag, long_flag, dest_name, help_text in required_path_options:
        parser.add_argument(short_flag, long_flag, dest=dest_name, type=str,
                            help=help_text, required=True)
359
+
360
+
361
def parse_json_config(json_file_path):
    """Load the json file at ``json_file_path`` and wrap its content in a ``CommonConfig``.

    Raises:
        Exception: if ``json_file_path`` is empty.
    """
    if json_file_path:
        return CommonConfig(load_json(json_file_path))
    raise Exception("config_input path can not be empty, please check.")
367
+
368
+
369
class _RobustFormatter(string.Formatter):
    """Formatter that leaves every placeholder it cannot resolve untouched."""

    def vformat(self, format_string, args, kwargs):
        result = []
        # parse() splits the template into literal text and placeholder parts
        for literal, field_name, format_spec, conversion in self.parse(format_string):
            result.append(literal)
            if field_name is None:
                continue
            try:
                # regular lookup, conversion and formatting of the field
                obj, _ = self.get_field(field_name, args, kwargs)
                if conversion:
                    obj = self.convert_field(obj, conversion)
                result.append(self.format_field(obj, format_spec))
            except Exception:
                # Whatever went wrong (missing key, bad spec, ...), write the
                # placeholder back as {field_name[!conversion][:format_spec]}.
                placeholder = '{' + field_name
                if conversion:
                    placeholder += '!' + conversion
                if format_spec:
                    placeholder += ':' + format_spec
                placeholder += '}'
                result.append(placeholder)
        return ''.join(result)


def _run_operator_generate_commond(cmd_args):
    """Generate the operator replication script described by the command line arguments.

    Steps: read the config, optionally extract the API from the dump file, validate
    the extracted json, collect the template settings and write the filled template
    to ``<api_output_path>/<api_full_name>.py``.

    Raises:
        ValueError: if the config or the extracted json is invalid, or if the
            backward entry carries no gradient input information.
    """
    common_config = parse_json_config(cmd_args.config_input)

    # Defaults for the case where no dump file is configured and the extracted
    # json file is supplied directly.
    framework = None
    real_data_path = None
    if common_config.dump_json_path:
        api_extract = APIExtractor(common_config.api_name, common_config.dump_json_path, common_config.extract_api_path)
        api_extract.extract_op()
        framework = api_extract.framework
        real_data_path = api_extract.real_data_path
    check_file_or_directory_path(common_config.extract_api_path)
    check_file_or_directory_path(cmd_args.api_output_path, isdir=True)
    json_content = common_config.check_user_settings()

    # Without an extraction step, take the metadata from the extracted file itself.
    if framework is None:
        framework = json_content.get(FRAMEWORK)
    if real_data_path is None:
        real_data_path = json_content.get(REAL_DATA_PATH)

    api_info = APIInfo.from_json(json_content, common_config.propagation)

    api_info_dict_forward = api_info.api_info_dict
    args_info_forward = api_info_dict_forward.get(Const.INPUT_ARGS)
    kwargs_info_forward = api_info_dict_forward.get(Const.INPUT_KWARGS)

    if common_config.propagation == Const.BACKWARD:
        api_full_name = api_info.backward_info.api_full_name
        api_info_dict_backward = api_info.backward_info.api_info_dict
        if Const.GRAD_INPUT in api_info_dict_backward:
            args_info_backward = api_info_dict_backward.get(Const.GRAD_INPUT)
        elif Const.INPUT in api_info_dict_backward:
            args_info_backward = api_info_dict_backward.get(Const.INPUT)
        else:
            raise ValueError(
                f"backward info of {api_full_name} contains neither {Const.GRAD_INPUT} nor {Const.INPUT}!")
    else:
        api_full_name = api_info.api_full_name
        args_info_backward = None

    op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward, args_info_backward)
    internal_settings = op_generate.get_settings(api_full_name)
    internal_settings[FRAMEWORK] = framework
    internal_settings[REAL_DATA_PATH] = real_data_path

    template_path = os.path.join(os.path.dirname(__file__), "operator_replication.template")
    operator_script_path = os.path.join(cmd_args.api_output_path,
                                        "{0}.py".format(internal_settings.get("api_full_name")))

    fmt = _RobustFormatter()
    with FileOpen(template_path, 'r') as ftemp, FileOpen(operator_script_path, 'w') as fout:
        code_template = ftemp.read()
        # format() rather than format_map(): it routes through vformat above
        fout.write(fmt.format(code_template, **internal_settings))

    change_mode(operator_script_path, FileCheckConst.DATA_FILE_AUTHORITY)

    logger.info(f"Generate operator script successfully and the name is {operator_script_path}.")
454
+
455
+
456
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the generator.
    cli_parser = argparse.ArgumentParser()
    _op_generator_parser(cli_parser)
    _run_operator_generate_commond(cli_parser.parse_args())