mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -0,0 +1,454 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ import argparse
18
+ import json
19
+ import os
20
+ import re
21
+ import math
22
+ import numpy as np
23
+ import torch
24
+
25
+
26
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import binary_standard_api, absolute_standard_api, ulp_standard_api, thousandth_standard_api
27
+ from msprobe.core.common.file_utils import FileOpen, load_json, save_json
28
+ from msprobe.core.common.utils import check_file_or_directory_path, check_op_str_pattern_valid, is_int
29
+ from msprobe.core.common.const import Const, MonitorConst, MsgConst
30
+ from msprobe.core.common.log import logger
31
+ from msprobe.core.common.file_utils import make_dir
32
+ from msprobe.core.common.utils import recursion_depth_decorator
33
+
34
# Dump-json "type" values treated as tensors when generating argument code
# (such arguments get device / CPU-bench transfer code appended).
TENSOR_DATA_LIST = [
    "torch.Tensor",
    "torch.nn.parameter.Parameter",
]

# torch dtype names, grouped by category (aliases listed next to their dtype).
TORCH_BOOL_TYPE = ["torch.bool"]
TORCH_INT_TYPE = [
    "torch.uint8",
    "torch.int8",
    "torch.int16", "torch.short",
    "torch.int32", "torch.int",
    "torch.int64", "torch.long",
]
TORCH_FLOAT_TYPE = [
    "torch.float16", "torch.half",
    "torch.bfloat16",
    "torch.float32", "torch.float",
    "torch.float64", "torch.double",
]
TORCH_COMPLEX_TYPE = [
    "torch.complex32", "torch.chalf",
    "torch.complex64", "torch.cfloat",
    "torch.complex128", "torch.cdouble",
]

# Accepted values of the first segment of a dumped API name.
OPERATOR_TYPE = ("Functional", "Tensor", "Torch")

# Maximum number of entries in an extracted-API json (forward + backward).
API_INFO = 2
# Accepted segment counts of a Const.SEP-separated full API name.
FOUR_SEGMENT = 4
FIVE_SEGMENT = 5
# Key of the data-file name inside a dump entry (joined with dump_data_dir on extraction).
DATA_NAME = "data_name"
# Maximum accepted length of the configured api_name.
API_MAX_LENGTH = 30
# Accepted values of the "propagation" and "data_mode" settings.
PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
DATAMODE_LIST = ["random_data", "real_data"]
51
+
52
+
53
class APIInfo:
    """One API entry taken from an extracted-API json file.

    Attributes:
        api_full_name: Const.SEP-separated full name; its first segment is the API type.
        api_info_dict: the json value stored under that name.
        backward_info: APIInfo of the matching backward entry, or None.
    """

    def __init__(self, api_full_name, api_info_dict, backward_info=None):
        self.api_full_name = api_full_name
        self.api_info_dict = api_info_dict
        self.backward_info = backward_info

    @property
    def api_type(self):
        """First Const.SEP-separated segment of api_full_name."""
        head, _, _ = self.api_full_name.partition(Const.SEP)
        return head

    @classmethod
    def from_json(cls, json_content, propagation):
        """Build the forward APIInfo (plus backward_info when propagation is backward).

        The first json entry is the forward API. For backward propagation the second
        entry is attached as ``backward_info``. Entry-count validation is not done here.

        Raises:
            ValueError: if the forward API type is not in OPERATOR_TYPE.
        """
        entries = list(json_content.items())

        name, info_dict = entries[0]
        forward_info = cls(api_full_name=name, api_info_dict=info_dict)

        if propagation == Const.BACKWARD:
            name, info_dict = entries[1]
            forward_info.backward_info = cls(api_full_name=name, api_info_dict=info_dict)

        if not forward_info.is_supported_type():
            raise ValueError(f"type {forward_info.api_type} of API is not supported!")
        return forward_info

    def is_supported_type(self):
        """Return True when api_type is one of OPERATOR_TYPE."""
        return self.api_type in OPERATOR_TYPE
80
+
81
class CommonConfig:
    """Settings for the operator-script generator, validated on construction.

    Reads ``dump_json_path``, ``api_name``, ``extract_api_path``, ``propagation``,
    ``data_mode``, ``random_seed`` and ``iter_times`` from ``json_config``;
    missing keys are stored as None.

    Raises:
        ValueError: for most invalid settings. The project path/pattern checkers
            called during validation may raise their own exception types.
    """

    def __init__(self, json_config):
        self.dump_json_path = json_config.get('dump_json_path')
        self.api_name = json_config.get('api_name')
        self.extract_api_path = json_config.get('extract_api_path')
        self.propagation = json_config.get('propagation')
        self.data_mode = json_config.get('data_mode')
        self.random_seed = json_config.get('random_seed')
        self.iter_times = json_config.get('iter_times')
        self._check_config()

    def check_user_settings(self):
        """Validate iter_times and load the extracted-API json.

        Returns:
            The loaded json dict. Its first entry is the forward API and, for
            backward propagation, its second entry is the backward API.

        Raises:
            ValueError: if iter_times is not positive, or the json content is empty,
                not a dict, has more than API_INFO entries, or lacks valid
                forward/backward entries for the configured propagation.
        """
        if self.iter_times <= 0:
            raise ValueError("iter_times should be an integer bigger than zero!")

        propagation = self.propagation
        json_content = load_json(self.extract_api_path)

        if not json_content:
            raise ValueError('json file is empty!')
        if not isinstance(json_content, dict):
            raise ValueError('content of json file is not a dict!')
        if len(json_content) > API_INFO:
            raise ValueError('json file has more than one API, the API only contains forward and backward info')

        # The first entry is always the forward API and must carry a dict.
        forward_item = next(iter(json_content.items()), None)
        if not forward_item or not isinstance(forward_item[1], dict):
            raise ValueError('Invalid forward API data in json_content!')

        if propagation == Const.BACKWARD:
            if len(json_content) < API_INFO:
                raise ValueError('Backward propagation requires contains forward and backward info!')
            backward_item = list(json_content.items())[1]
            if not isinstance(backward_item[1], dict):
                raise ValueError('Invalid backward API data in json_content!')

        return json_content

    def _check_config(self):
        """Validate the settings read in __init__ and create the extraction output directory."""
        if self.dump_json_path:
            check_file_or_directory_path(self.dump_json_path)
        if self.api_name:
            check_op_str_pattern_valid(self.api_name)
            if len(self.api_name) > API_MAX_LENGTH:
                raise ValueError(f'API name {self.api_name} is too long!')
        # extract_api_path is always used (extraction output and generator input), so a
        # missing value must fail here with a clear ValueError rather than a TypeError
        # from os.path.dirname(None).
        if not isinstance(self.extract_api_path, str) or not self.extract_api_path:
            raise ValueError('extract_api_path is invalid, it should be a non-empty path')
        output_dir = os.path.dirname(self.extract_api_path)
        # A bare file name has no directory component, so there is nothing to create.
        if output_dir:
            make_dir(output_dir)
        if self.propagation and self.propagation not in PROPAGATION_LIST:
            raise ValueError(f'propagation is invalid, it should be one of {PROPAGATION_LIST}')
        if self.data_mode and self.data_mode not in DATAMODE_LIST:
            raise ValueError(f'data_mode is invalid, it should be one of {DATAMODE_LIST}')
        if not is_int(self.random_seed):
            raise ValueError('random_seed is invalid, it should be an int')
        if not is_int(self.iter_times):
            raise ValueError('iter_times is invalid, it should be an int')
149
+
150
class APIExtractor:
    """Extract all dump entries of one API from a dump json into a separate json file."""

    def __init__(self, api_name, dump_json_path, output_file):
        self.api_name = api_name
        self.dump_json_path = dump_json_path
        self.output_file = output_file
        self.data = None

    def extract_op(self):
        """Copy every ``data`` entry whose key starts with ``<api_name>.`` into output_file.

        When the dump json has a non-empty ``dump_data_dir``, the ``data_name`` of each
        extracted entry is rewritten to a path joined with that directory. If no entry
        matches, an error is logged and nothing is written.
        """
        self.data = load_json(self.dump_json_path)
        # Raw f-string: "\." is a regex escape. In a plain string literal it is an
        # invalid escape sequence, which warns on modern Python.
        extract_key_pattern = re.compile(rf"^{re.escape(self.api_name)}\..+")
        real_data_path = self.data.get('dump_data_dir', '')
        new_data = {}
        for key, value in self.data.get('data', {}).items():
            if not extract_key_pattern.match(key):
                continue
            if real_data_path:
                value = self.load_real_data_path(value, real_data_path)
            new_data[key] = value
        if not new_data:
            logger.error(f"Error: The api '{self.api_name}' does not exist in the file.")
        else:
            save_json(self.output_file, new_data, indent=4)
            logger.info(
                f"The api '{self.api_name}' has been successfully extracted and saved in: {self.output_file}")

    def load_real_data_path(self, value, dump_data_dir):
        """Prefix every ``data_name`` in the argument/output lists of ``value`` with dump_data_dir.

        ``value`` is modified in place and also returned.
        """
        parameters = [Const.INPUT_ARGS, Const.GRAD_INPUT, Const.INPUT, Const.OUTPUT, Const.GRAD_OUTPUT]
        for parameter in parameters:
            # "or []" also covers keys that are present with an explicit null value.
            for v in value.get(parameter) or []:
                if v is not None:
                    self.update_data_name(v, dump_data_dir)
        return value

    def update_data_name(self, data, dump_data_dir):
        """Recursively join dump_data_dir onto ``data_name`` in dicts nested in lists.

        Non-dict leaves (e.g. None inside a nested list) are left untouched.
        """
        if isinstance(data, list):
            for item in data:
                self.update_data_name(item, dump_data_dir)
        elif isinstance(data, dict) and DATA_NAME in data:
            data[DATA_NAME] = os.path.join(dump_data_dir, data[DATA_NAME])
188
+
189
+ class OperatorScriptGenerator:
190
+ def __init__(self, common_config, args_info_forward, kwargs_info_forward, args_info_backward):
191
+ self.common_config = common_config
192
+ self.args_info_forward = args_info_forward
193
+ self.kwargs_info_forward = kwargs_info_forward
194
+ self.args_info_backward = args_info_backward
195
+
196
+ @staticmethod
197
+ def get_compare_standard(api_name):
198
+ api_standard_map = {
199
+ "binary_standard_api": "CompareStandard.BINARY_EQUALITY_STANDARD",
200
+ "absolute_standard_api": "CompareStandard.ABSOLUTE_THRESHOLD_STANDARD",
201
+ "ulp_standard_api": "CompareStandard.ULP_ERROR_STANDARD",
202
+ "thousandth_standard_api": "CompareStandard.THOUSANDTH_STANDARD"
203
+ }
204
+ for standard_api, standard_value in api_standard_map.items():
205
+ if api_name in globals()[standard_api]:
206
+ return standard_value
207
+ return "CompareStandard.BENCHMARK_STANDARD"
208
+
209
+ @staticmethod
210
+ def extract_detailed_api_segments(full_api_name):
211
+ """
212
+ Function Description:
213
+ Extract the name of the API.
214
+ Parameter:
215
+ full_api_name_with_direction_status: Full name of the API. Example: torch.matmul.0.forward.output.0
216
+ Return:
217
+ api_name: Name of api. Example: matmul, mul, etc.
218
+ full_api_name: Full name of api. Example: torch.matmul.0
219
+ direction_status: Direction status of api. Example: forward, backward, etc.
220
+ """
221
+ api_parts = full_api_name.split(Const.SEP)
222
+ api_parts_length = len(api_parts)
223
+ api_type, api_name, api_order = None, None, None
224
+ if api_parts_length == FOUR_SEGMENT:
225
+ api_type, api_name, api_order, _ = api_parts
226
+ elif api_parts_length == FIVE_SEGMENT:
227
+ api_type, prefix, api_name, api_order, _ = api_parts
228
+ api_name = Const.SEP.join([prefix, api_name])
229
+ return api_type, api_name, api_order
230
+
231
+ def get_settings(self, api_full_name):
232
+ '''
233
+ internal_settings contain all information needed for the operator program.
234
+ keys:
235
+ api_full_name: api_type.api_name.ordinal_number
236
+ api_type: type of API, one of torch.nn.functional, torch.Tensor or Torch
237
+ api_name: name of API
238
+ ordinal_number: how many times the same api has been called
239
+ direction_status: forward
240
+ random_seed: if mode is random_data, random seed is random_seed
241
+ iter_times: if mode is random_data, generate iter_times group of data; if mode is real_data, iter_times does not matter
242
+ args_element_assignment: code for args assignment
243
+ args_list_generator_device: code for generate args list on device
244
+ args_list_generator_bench: code for generate args list on bench
245
+ kwargs_value_assignment: code for kwargs assignment
246
+ kwargs_dict_generator_device: code for generate kwargs dict on device
247
+ kwargs_dict_generator_bench: code for generate kwargs dict on bench
248
+ '''
249
+ # Generate an internal setting dictionary based on user settings
250
+ # including API name, type, comparison standard, random seed, number of iterations and other information
251
+ internal_settings = {}
252
+ internal_settings["propagation"] = self.common_config.propagation
253
+ internal_settings["api_full_name"] = api_full_name
254
+ api_type, api_name, ordinal_number = self.extract_detailed_api_segments(api_full_name)
255
+ if api_type == "Functional":
256
+ internal_settings["api_type"] = "torch.nn.functional"
257
+ elif api_type == "Tensor":
258
+ internal_settings["api_type"] = "torch.Tensor"
259
+ else:
260
+ internal_settings["api_type"] = "torch"
261
+ internal_settings["api_name"] = api_name
262
+ internal_settings["compare_standard"] = self.get_compare_standard(api_name)
263
+ internal_settings["ordinal_number"] = ordinal_number
264
+ internal_settings["direction_status"] = self.common_config.propagation
265
+ internal_settings["random_seed"] = self.common_config.random_seed
266
+ if self.common_config.data_mode == "real_data":
267
+ internal_settings["iter_times"] = 1
268
+ else:
269
+ internal_settings["iter_times"] = self.common_config.iter_times
270
+ internal_settings["args_element_assignment"] = self.generate_args_element_assignment_code(self.args_info_forward)
271
+ internal_settings["args_list_generator_device"] = self.generate_args_list(self.args_info_forward, flag_device=True)
272
+ internal_settings["args_list_generator_bench"] = self.generate_args_list(self.args_info_forward, flag_device=False)
273
+ internal_settings["kwargs_value_assignment"] = self.generate_kwargs_value_assignment_code(self.kwargs_info_forward)
274
+ internal_settings["kwargs_dict_generator_device"] = self.generate_kwargs_dict(self.kwargs_info_forward, flag_device=True)
275
+ internal_settings["kwargs_dict_generator_bench"] = self.generate_kwargs_dict(self.kwargs_info_forward, flag_device=False)
276
+ if self.common_config.propagation == Const.BACKWARD:
277
+ internal_settings["args_element_assignment_backward"] = self.generate_args_element_assignment_code(
278
+ self.args_info_backward)
279
+ internal_settings["args_list_generator_device_backward"] = self.generate_args_list(self.args_info_backward, flag_device=True)
280
+ internal_settings["args_list_generator_bench_backward"] = self.generate_args_list(self.args_info_backward, flag_device=False)
281
+ else:
282
+ internal_settings["args_element_assignment_backward"] = ''
283
+ internal_settings["args_list_generator_device_backward"] = ''
284
+ internal_settings["args_list_generator_bench_backward"] = ''
285
+
286
+ return internal_settings
287
+
288
+ @recursion_depth_decorator("OpGenerator: OperatorScriptGenerator.recursive_args_element_assignment")
289
+ def recursive_args_element_assignment(self, args_info, name_number):
290
+ args_element_assignment = ""
291
+ for index, arg in enumerate(args_info):
292
+ if isinstance(arg, (list, tuple)):
293
+ new_args_element_assignment = self.recursive_args_element_assignment(arg, name_number + "_" + str(index))
294
+ args_element_assignment += new_args_element_assignment
295
+ else:
296
+ arg["parameter_name"] = "arg" + name_number + "_" + str(index)
297
+ args_element_assignment += " " + "arg_info" + name_number + "_" + str(index) + " = " + "{}".format(str(arg)) + MsgConst.SPECIAL_CHAR[0]
298
+ args_element_assignment += " " + "arg" + name_number + "_" + str(index) + " = " + "generate_data(arg_info" + name_number + "_" + str(index) + ")" + MsgConst.SPECIAL_CHAR[0]
299
+ return args_element_assignment
300
+
301
+
302
+ def generate_args_element_assignment_code(self, args_info):
303
+ args_element_assignment = self.recursive_args_element_assignment(args_info, "")
304
+ return args_element_assignment
305
+
306
+ @recursion_depth_decorator("OpGenerator: OperatorScriptGenerator.recursive_args_list")
307
+ def recursive_args_list(self, args_info, flag_device=False, flag_bench=False):
308
+ args_list_generator = ""
309
+ for _, arg in enumerate(args_info):
310
+ if isinstance(arg, (list, tuple)):
311
+ (left_bracket, right_bracket) = ("[", "]") if isinstance(arg, list) else ("(", ")")
312
+ args_list_generator += left_bracket
313
+ new_args_list_generator = self.recursive_args_list(arg, flag_device=flag_device, flag_bench=flag_bench)
314
+ args_list_generator += new_args_list_generator
315
+ args_list_generator += right_bracket
316
+ else:
317
+ args_list_generator += arg.get("parameter_name")
318
+ if arg.get("type") in TENSOR_DATA_LIST:
319
+ if flag_device:
320
+ args_list_generator += ".to(device)"
321
+ if flag_bench:
322
+ args_list_generator += '.to(torch.device("cpu"))'
323
+ args_list_generator += ".to(RAISE_PRECISION.get(str(" + arg.get("parameter_name") + ".dtype), " + arg.get("parameter_name") + ".dtype))"
324
+ args_list_generator += Const.COMMA
325
+ return args_list_generator
326
+
327
+ def generate_args_list(self, args_info, flag_device):
328
+ if flag_device:
329
+ args_list_generator = self.recursive_args_list(args_info, flag_device=True)
330
+ else:
331
+ args_list_generator = self.recursive_args_list(args_info, flag_bench=True)
332
+ return args_list_generator
333
+
334
+ @recursion_depth_decorator("OpGenerator: OperatorScriptGenerator.recursive_kwargs_value_assignment")
335
+ def recursive_kwargs_value_assignment(self, info, key_name, name_number):
336
+ kwargs_value_assignment = ""
337
+ if isinstance(info, dict):
338
+ if info.get("type") == "torch.device" or info.get("type") == "torch.dtype":
339
+ kwargs_value_assignment += " " + "kwarg_" + key_name + name_number + " = " + info.get("value")
340
+ else:
341
+ kwargs_value_assignment += " " + "kwarg_info_" + key_name + name_number + " = " + "{}".format(str(info)) + MsgConst.SPECIAL_CHAR[0]
342
+ kwargs_value_assignment += " " + "kwarg_" + key_name + name_number + " = " + "generate_data(kwarg_info_" + key_name + name_number + ")" + MsgConst.SPECIAL_CHAR[0]
343
+ info["parameter_name"] = "kwarg_" + key_name + name_number
344
+ else:
345
+ for index, arg in enumerate(info):
346
+ new_kwargs_value_assignment = self.recursive_kwargs_value_assignment(arg, key_name, name_number + "_" + str(index))
347
+ kwargs_value_assignment += new_kwargs_value_assignment
348
+ return kwargs_value_assignment
349
+
350
+ def generate_kwargs_value_assignment_code(self, kwargs_info):
351
+ kwargs_value_assignment = ""
352
+ for key, value in kwargs_info.items():
353
+ kwargs_value_assignment += self.recursive_kwargs_value_assignment(value, key, "")
354
+ return kwargs_value_assignment
355
+
356
@recursion_depth_decorator("OpGenerator: OperatorScriptGenerator.recursive_kwargs_dict")
def recursive_kwargs_dict(self, info, flag_device=False, flag_bench=False):
    """Render one keyword-argument value as an expression for the generated script.

    A dict leaf renders as the variable name stored in
    ``info["parameter_name"]`` (set earlier by
    ``recursive_kwargs_value_assignment``). If the leaf's ``type`` is in
    ``TENSOR_DATA_LIST``, further calls are appended:
      * ``flag_device`` adds ``.to(device)``;
      * ``flag_bench`` moves the tensor to CPU and casts it through
        ``RAISE_PRECISION`` (falling back to its own dtype).
    A list renders as ``[...]`` and any other non-dict value is iterated and
    rendered as ``(...)``, with every element followed by ``Const.COMMA``.

    Args:
        info: Dumped argument description (dict leaf or nested sequence).
        flag_device: Render the device-side variant.
        flag_bench: Render the CPU benchmark variant.

    Returns:
        The generated expression text.
    """
    if not isinstance(info, dict):
        opening, closing = ("[", "]") if isinstance(info, list) else ("(", ")")
        elements = "".join(
            self.recursive_kwargs_dict(item, flag_device=flag_device, flag_bench=flag_bench) + Const.COMMA
            for item in info
        )
        return opening + elements + closing

    variable = info.get("parameter_name")
    # Concatenating onto "" raises TypeError if the leaf was never named.
    expression = "" + variable
    if info.get("type") not in TENSOR_DATA_LIST:
        return expression
    if flag_device:
        expression += ".to(device)"
    if flag_bench:
        expression += '.to(torch.device("cpu"))'
        expression += ".to(RAISE_PRECISION.get(str(" + variable + ".dtype), " + variable + ".dtype))"
    return expression
375
+
376
+
377
def generate_kwargs_dict(self, kwargs_info, flag_device):
    """Render the entries of the keyword-argument dict literal for the generated script.

    Each entry is ``"<key>"``, then ``MonitorConst.VPP_SEP``, then the
    expression from ``recursive_kwargs_dict``, then ``Const.COMMA``. The
    surrounding braces are not emitted here.

    Args:
        kwargs_info: Mapping of keyword name -> dumped argument description.
        flag_device: Truthy selects the device variant; otherwise the CPU
            bench variant is rendered.

    Returns:
        The concatenated entries (an empty string for no kwargs).
    """
    render_options = {"flag_device": True} if flag_device else {"flag_bench": True}
    entries = []
    for name, description in kwargs_info.items():
        # NOTE(review): MonitorConst.VPP_SEP is presumably ":" and serves as the
        # key/value separator here; confirm, since a monitor constant is an
        # unexpected dependency for this.
        entries.append('"' + name + '"' + MonitorConst.VPP_SEP
                       + self.recursive_kwargs_dict(description, **render_options) + Const.COMMA)
    return "".join(entries)
386
+
387
+
388
+
389
def op_generator_parser(parser):
    """Register the op-generator command-line options on ``parser``.

    Args:
        parser: ``argparse.ArgumentParser`` (or sub-parser) extended in place.
    """
    # -i is mandatory on the CLI. The old help text wrongly said "<Optional>".
    # default='' is kept because parse_json_config treats an empty path as
    # "use the bundled config.json".
    parser.add_argument("-i", "--config_input", dest="config_input", default='', type=str,
                        help="<Required> Path of config json file", required=True)
    # main() checks this value is a directory and writes "<api_full_name>.py"
    # into it. It is not the path of the extracted api json (that path comes
    # from the config file).
    parser.add_argument("-o", "--api_output_path", dest="api_output_path", type=str,
                        help="<Required> Directory where the generated operator script is saved.",
                        required=True)
395
+
396
def parse_json_config(json_file_path):
    """Load the generator configuration and wrap it in ``CommonConfig``.

    Args:
        json_file_path: Path of the user's config JSON. When it is empty or
            otherwise falsy, ``config.json`` in the parent of this file's
            directory (resolved through ``os.path.realpath``) is used instead.

    Returns:
        ``CommonConfig`` built from the loaded JSON content.
    """
    if json_file_path:
        config_path = json_file_path
    else:
        fallback_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        config_path = os.path.join(fallback_dir, "config.json")
    return CommonConfig(load_json(config_path))
403
+
404
def main():
    """Command-line entry point: generate a standalone operator replication script.

    Steps:
      1. Parse ``-i/--config_input`` and ``-o/--api_output_path``.
      2. Load the config. If ``dump_json_path`` is set, extract the configured
         API's data into ``extract_api_path`` with ``APIExtractor``.
      3. Validate the extracted-API path and the output directory, read the
         user settings and build ``APIInfo`` for the configured propagation.
      4. Collect the template settings from ``OperatorScriptGenerator`` and
         render ``operator_replication.template`` into
         ``<api_output_path>/<api_full_name>.py``.

    Raises:
        ValueError: If backward propagation is requested but the backward
            record contains neither ``Const.GRAD_INPUT`` nor ``Const.INPUT``.
    """
    parser = argparse.ArgumentParser()
    op_generator_parser(parser)
    cmd_args = parser.parse_args()

    common_config = parse_json_config(cmd_args.config_input)

    if common_config.dump_json_path:
        api_extract = APIExtractor(common_config.api_name, common_config.dump_json_path,
                                   common_config.extract_api_path)
        api_extract.extract_op()
    check_file_or_directory_path(common_config.extract_api_path)
    check_file_or_directory_path(cmd_args.api_output_path, isdir=True)
    json_content = common_config.check_user_settings()
    api_info = APIInfo.from_json(json_content, common_config.propagation)

    # read and check json: the forward inputs are needed in both propagation modes
    api_full_name_forward, api_info_dict_forward = api_info.api_full_name, api_info.api_info_dict
    args_info_forward = api_info_dict_forward.get(Const.INPUT_ARGS)
    kwargs_info_forward = api_info_dict_forward.get(Const.INPUT_KWARGS)

    if common_config.propagation == Const.BACKWARD:
        api_full_name_backward = api_info.backward_info.api_full_name
        api_info_dict_backward = api_info.backward_info.api_info_dict
        if Const.GRAD_INPUT in api_info_dict_backward:
            args_info_backward = api_info_dict_backward.get(Const.GRAD_INPUT)
        elif Const.INPUT in api_info_dict_backward:
            args_info_backward = api_info_dict_backward.get(Const.INPUT)
        else:
            # Previously this fell through to an UnboundLocalError on
            # args_info_backward. Fail with an actionable message instead.
            raise ValueError(
                f"Backward data of {api_full_name_backward} contains neither "
                f"'{Const.GRAD_INPUT}' nor '{Const.INPUT}'."
            )
        op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward,
                                              args_info_backward)
        internal_settings = op_generate.get_settings(api_full_name_backward)
    else:
        op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward, None)
        internal_settings = op_generate.get_settings(api_full_name_forward)

    template_path = os.path.join(os.path.dirname(__file__), "operator_replication.template")
    operator_script_path = os.path.join(cmd_args.api_output_path,
                                        "{0}.py".format(internal_settings.get("api_full_name")))

    try:
        with FileOpen(template_path, 'r') as ftemp, FileOpen(operator_script_path, 'w') as fout:
            code_template = ftemp.read()
            fout.write(code_template.format(**internal_settings))
    except OSError:
        logger.error(f"Failed to open file. Please check file {template_path} or {operator_script_path}.")
    else:
        # Report success only when the script was actually written. Previously
        # this was also logged right after the failure message above.
        logger.info(f"Generate operator script successfully and the name is {operator_script_path}.")
451
+
452
+
453
if "__main__" == __name__:
    # Run the generator only when this file is executed as a script, not on import.
    main()