mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,2081 @@
1
+ import os
2
+ import re
3
+ import stat
4
+ import time
5
+ from enum import Enum, auto
6
+ from abc import ABC, abstractmethod
7
+ import csv
8
+ import random
9
+
10
+ import gc
11
+ import sys
12
+ from pathlib import Path
13
+ import mindspore
14
+ from mindspore import ops
15
+
16
+
17
+ from tabulate import tabulate
18
+
19
+ import logging
20
+
21
+ import traceback
22
+
23
+
24
+
25
def error_log_with_exp(self, msg: str, exp: Exception):
    """Log *msg* at ERROR level together with the traceback of *exp*.

    Args:
        msg: human-readable error description.
        exp: exception instance whose type/value/traceback should be recorded.
    """
    # Pass the full exception triple via exc_info so the log formatter
    # renders the traceback alongside the message.
    exc_triple = (type(exp), exp, exp.__traceback__)
    self.error(msg, exc_info=exc_triple)


# Expose the helper as a method on every Logger instance.
logging.Logger.error_log_with_exp = error_log_with_exp
35
+
36
+
37
+
38
# Basic configuration: INFO level, console output by default.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S',
)

# Module-wide logger: the root logger.
logger = logging.getLogger()
44
+
45
+
46
+ # ======= 常数类 =======
47
+
48
class CodedException(Exception):
    """Base class for exceptions identified by an error code.

    Subclasses must define ``err_strs``, a mapping from error code to a
    human-readable message prefix; ``error_info`` is appended to the prefix.
    """

    def __init__(self, code, error_info=''):
        super().__init__()
        self.code = code
        # Bug fix: use a default of '' for unknown codes — the previous
        # `.get(code)` returned None and `None + error_info` raised TypeError.
        self.error_info = self.err_strs.get(code, '') + error_info

    def __str__(self):
        return self.error_info
56
+
57
+
58
class ApiAccuracyCheckerException(CodedException):
    """Errors raised by the API accuracy checker, keyed by error code."""

    # Error codes
    ParseJsonFailed = 0
    UnsupportType = 1
    WrongValue = 2
    ApiWrong = 3

    # Message prefix for each code (runtime strings — wording kept stable).
    err_strs = {
        ParseJsonFailed: "[msprobe] Api Accuracy Checker parse json failed: ",
        UnsupportType: "[msprobe] Api Accuracy Checker get unsupported type: ",
        WrongValue: "[msprobe] Api Accuracy Checker get wrong value: ",
        ApiWrong: "[msprobe] Api Accuracy Checker something wrong with api: ",
    }
69
+
70
+
71
class FileCheckConst:
    """Constants used by file validity / permission checks."""

    # Access modes
    READ_ABLE = "read"
    WRITE_ABLE = "write"
    READ_WRITE_ABLE = "read and write"

    # Path length limits
    DIRECTORY_LENGTH = 4096
    FILE_NAME_LENGTH = 255

    # Path validation regexes
    FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
    FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'

    # Recognized file suffixes
    PKL_SUFFIX = ".pkl"
    NUMPY_SUFFIX = ".npy"
    JSON_SUFFIX = ".json"
    PT_SUFFIX = ".pt"
    CSV_SUFFIX = ".csv"
    XLSX_SUFFIX = ".xlsx"
    YAML_SUFFIX = ".yaml"
    IR_SUFFIX = ".ir"
    ZIP_SUFFIX = ".zip"
    SHELL_SUFFIX = ".sh"

    # Per-suffix maximum sizes in bytes (expressed as powers of 1024).
    MAX_PKL_SIZE = 1 * 1024 ** 3        # 1 GiB
    MAX_NUMPY_SIZE = 10 * 1024 ** 3     # 10 GiB
    MAX_JSON_SIZE = 1 * 1024 ** 3       # 1 GiB
    MAX_PT_SIZE = 10 * 1024 ** 3        # 10 GiB
    MAX_CSV_SIZE = 1 * 1024 ** 3        # 1 GiB
    MAX_XLSX_SIZE = 1 * 1024 ** 3       # 1 GiB
    MAX_YAML_SIZE = 1 * 1024 ** 3       # 1 GiB
    MAX_IR_SIZE = 1 * 1024 ** 3         # 1 GiB
    MAX_ZIP_SIZE = 10 * 1024 ** 3       # 10 GiB
    MAX_FILE_IN_ZIP_SIZE = 1 * 1024 ** 3  # 1 GiB
    COMMOM_FILE_SIZE = 1024 ** 2        # 1 MiB

    # Path-kind labels
    DIR = "dir"
    FILE = "file"

    # Permission bits applied to created artifacts
    DATA_DIR_AUTHORITY = 0o750
    DATA_FILE_AUTHORITY = 0o640

    # Suffix -> maximum allowed size
    FILE_SIZE_DICT = {
        PKL_SUFFIX: MAX_PKL_SIZE,
        NUMPY_SUFFIX: MAX_NUMPY_SIZE,
        JSON_SUFFIX: MAX_JSON_SIZE,
        PT_SUFFIX: MAX_PT_SIZE,
        CSV_SUFFIX: MAX_CSV_SIZE,
        XLSX_SUFFIX: MAX_XLSX_SIZE,
        YAML_SUFFIX: MAX_YAML_SIZE,
        IR_SUFFIX: MAX_IR_SIZE,
        ZIP_SUFFIX: MAX_ZIP_SIZE
    }

    # Regex of leading characters considered CSV-injection risks.
    CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]'
119
+
120
class Const:
    """General-purpose constants shared across the generated script."""

    MAX_DEPTH = 10

    # Framework identifiers
    PT_FRAMEWORK = "pytorch"
    MS_FRAMEWORK = "mindspore"
    MT_FRAMEWORK = "mindtorch"

    # Separator used when joining api-name parts
    SEP = "."

    # api_info field names
    KWARGS = 'kwargs'
    INPUT = 'input'
    OUTPUT = 'output'
    INPUT_ARGS = 'input_args'
    INPUT_KWARGS = 'input_kwargs'
    GRAD_INPUT = 'grad_input'
    GRAD_OUTPUT = 'grad_output'

    # Pass direction labels
    BACKWARD = 'backward'
    FORWARD = 'forward'
135
+
136
+
137
class CompareConst:
    """Status labels, result-column names and accuracy thresholds."""

    # compare result status values
    PASS = 'pass'
    WARNING = 'Warning'
    ERROR = 'error'
    TRUE = 'TRUE'
    FALSE = 'FALSE'
    SKIP = 'SKIP'

    # compare result column names
    COSINE = "Cosine"
    EUC_DIST = "EucDist"
    MAX_ABS_ERR = "MaxAbsErr"
    MAX_RELATIVE_ERR = "MaxRelativeErr"
    MIN_RELATIVE_ERR = "MinRelativeErr"
    MEAN_RELATIVE_ERR = "MeanRelativeErr"
    NORM_RELATIVE_ERR = "NormRelativeErr"

    # accuracy standards (pass thresholds)
    COS_THRESHOLD = 0.99
    MAX_ABS_ERR_THRESHOLD = 0.001
    MAX_RELATIVE_ERR_THRESHOLD = 0.001
    COS_MAX_THRESHOLD = 0.9
    MAX_ABS_ERR_MAX_THRESHOLD = 1
161
+
162
class MsCompareConst:
    """Constants specific to the MindSpore API accuracy checker."""

    # api_info category labels
    MINT = "Mint"
    MINT_FUNCTIONAL = "MintFunctional"
    TENSOR_API = "Tensor"
    FUNCTIONAL_API = "Functional"
    FUSION_API = "FUSION"

    API_NAME_STR_LENGTH = 4
    MAX_RECURSION_DEPTH = 20

    # MindTorch api_info category labels
    MINDTORCH_TENSOR = "Tensor"
    MINDTORCH = "Torch"
    MINDTORCH_FUNC = "Functional"
    MINDTORCH_NPU = "NPU"
    MINDTORCH_DIST = "Distributed"

    MT_VALID_API_TYPES = [
        MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR
    ]
    SUPPORTED_FUSION_LIST = ["flash_attention_score"]

    # dump-json top-level field names
    TASK_FIELD = "task"
    STATISTICS_TASK = "statistics"
    FRAMEWORK = "framework"
    TENSOR_TASK = "tensor"
    DUMP_DATA_DIR_FIELD = "dump_data_dir"
    DATA_FIELD = "data"

    # supported-api yaml resource
    SUPPORTED_API_LIST_FILE = "checker_support_api.yaml"
    SUPPORTED_TENSOR_LIST_KEY = "tensor"

    # detail csv column headers / base file name
    DETAIL_CSV_API_NAME = "API Name"
    DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
    DETAIL_CSV_TESTED_DTYPE = "Tested Dtype"
    DETAIL_CSV_SHAPE = "Shape"
    DETAIL_CSV_PASS_STATUS = "Status"
    DETAIL_CSV_MESSAGE = "Message"
    DETAIL_CSV_FILE_NAME = "accuracy_checking_details"

    # result csv column headers / base file name
    RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success"
    RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success"
    RESULT_CSV_FILE_NAME = "accuracy_checking_result"

    EPSILON = 1e-8

    class ProcessStatus:
        # Per-api processing outcome labels.
        SUCCESS = "success"
        API_NOT_FOUND = "api_not_found"
        EXCEPTION_SKIP = "exception_skip"
216
+
217
+ # ======= mindtorch support ========
218
+
219
+ import torch as mindtorch
220
+ from torch import Tensor as mindtorch_tensor
221
+ import torch.nn.functional as mindtorch_func
222
+ import torch.distributed as mindtorch_dist
223
+
224
# Assume a valid dual PyTorch/MindTorch environment until proven otherwise.
is_valid_pt_mt_env = True
225
+
226
+
227
def is_mindtorch():
    """Return True when the importable `torch` is actually MindTorch.

    Detection: construct a torch tensor and check whether it is an instance
    of ``mindspore.Tensor``. When either package is missing, report False.
    """
    try:
        import torch as probe_torch
        from mindspore import Tensor as MindsporeTensor
    except ImportError:
        # No torch or no mindspore available -> cannot be MindTorch.
        return False

    probe_tensor = probe_torch.tensor(0.0)
    return isinstance(probe_tensor, MindsporeTensor)
239
+
240
+
241
def remove_torch_related_paths():
    """Remove the sys.path entry holding the current `torch` package.

    Only acts when the importable `torch` is MindTorch (see is_mindtorch);
    used to peel MindTorch's site-packages directory off sys.path so that
    a genuine PyTorch can be imported afterwards. Returns None.
    """
    removed_paths = []  # (resolved_path, original_index) pairs, for debugging
    if not is_mindtorch():
        return
    try:
        import torch as remove_torch
        torch_file = remove_torch.__file__
    except ImportError:
        return

    torch_dir = os.path.dirname(torch_file)

    torch_dir_path = Path(torch_dir).resolve()
    parent_dir = torch_dir_path.parent

    paths_to_remove = [str(parent_dir)]

    for path in paths_to_remove:
        try:
            path_resolved = str(Path(path).resolve())
        except Exception as error:
            logger.debug(f"Failed to resolve path {path}: {error}")
            # Bug fix: previously execution fell through here and used
            # `path_resolved` while it was unbound (UnboundLocalError) or
            # stale from a previous iteration.
            continue

        if path_resolved in sys.path:
            index = sys.path.index(path_resolved)
            removed_paths.append((path_resolved, index))
            sys.path.pop(index)

    return
271
+
272
+
273
def clear_torch_from_sys_modules():
    """Drop `torch` and every `torch.*` submodule from sys.modules."""
    # Collect names first: mutating sys.modules while iterating is unsafe.
    stale_modules = [name for name in sys.modules
                     if name == "torch" or name.startswith("torch.")]
    for name in stale_modules:
        del sys.modules[name]
281
+
282
+
283
def set_pt_mt_env_invalid():
    """Flip the module-level dual-environment flag to invalid (False)."""
    global is_valid_pt_mt_env
    is_valid_pt_mt_env = False
286
+
287
+
288
def delete_torch_paths():
    """Iteratively strip MindTorch from the import environment.

    Repeatedly clears torch entries from sys.modules and removes MindTorch's
    sys.path entry, for at most MsCompareConst.MAX_RECURSION_DEPTH rounds.

    Raises:
        Exception: if `torch` still resolves to MindTorch after the limit.
    """
    if not is_mindtorch():
        set_pt_mt_env_invalid()

    clear_torch_from_sys_modules()

    for count_delete_env_path in range(MsCompareConst.MAX_RECURSION_DEPTH):
        if not is_mindtorch():
            break

        remove_torch_related_paths()

        clear_torch_from_sys_modules()

        if count_delete_env_path >= MsCompareConst.MAX_RECURSION_DEPTH - 1:
            # Bug fix: the message previously interpolated
            # Const.MAX_RECURSION_DEPTH (10) although the loop bound is
            # MsCompareConst.MAX_RECURSION_DEPTH (20); report the real limit.
            raise Exception(f"Please check if you have a valid PyTorch and MindTorch environment, and ensure "
                            f"the PYTHONPATH environment variable depth does not exceed "
                            f"{MsCompareConst.MAX_RECURSION_DEPTH}.")
306
+
307
+
308
# Decide whether a usable dual PyTorch/MindTorch environment exists.
if is_mindtorch():
    # MindTorch is masquerading as torch: temporarily strip it from the
    # import machinery, import the real PyTorch, then restore sys.path.
    initial_sys_path = sys.path.copy()
    delete_torch_paths()

    gc.collect()

    import torch

    if is_mindtorch():
        # Still MindTorch after cleanup: no genuine PyTorch available.
        set_pt_mt_env_invalid()

    sys.path = initial_sys_path
else:
    set_pt_mt_env_invalid()


if not is_valid_pt_mt_env:
    # Single-framework environment: import torch as-is (may be MindTorch).
    import torch
328
+
329
+
330
+
331
# ======= constants =======

import numpy as np
from mindspore._c_expression import typing
from mindspore.common import dtype as mstype


# String forms of torch tensor-like types as stored in api_info json.
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
# String forms of torch dtypes, grouped by kind.
TORCH_BOOL_TYPE = ["torch.bool"]
TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int",
                  "torch.int64", "torch.long"]
TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float",
                    "torch.float64", "torch.double"]
TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128", "torch.cdouble"]
# NOTE: doubled braces are str.format escapes — this file is a template
# rendered with .format(), so "{{" / "}}" emit literal "{" / "}".
# Maps a dtype name to the higher precision used for the bench computation.
RAISE_PRECISION = {{
    "torch.float16": torch.float32,
    "torch.half": torch.float32,
    "torch.bfloat16": torch.float32,
    "torch.float32": torch.float64,
    "torch.float": torch.float64
}}
THOUSANDTH_THRESHOLDING = 0.001
BACKWARD = 'backward'
DIR = "dir"
FILE = "file"
READ_ABLE = "read"
WRITE_ABLE = "write"
READ_WRITE_ABLE = "read and write"
DIRECTORY_LENGTH = 4096
FILE_NAME_LENGTH = 255
# User-facing error messages (runtime strings, intentionally localized).
SOFT_LINK_ERROR = "检测到软链接"
FILE_PERMISSION_ERROR = "文件权限错误"
INVALID_FILE_ERROR = "无效文件"
ILLEGAL_PATH_ERROR = "非法文件路径"
ILLEGAL_PARAM_ERROR = "非法打开方式"
FILE_TOO_LARGE_ERROR = "文件过大"
FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
# Maximum allowed file size per suffix, in bytes (doubled braces: see above).
FILE_SIZE_DICT = {{
    ".pkl": 1073741824,  # 1 * 1024 * 1024 * 1024
    ".npy": 10737418240,  # 10 * 1024 * 1024 * 1024
    ".json": 1073741824,  # 1 * 1024 * 1024 * 1024
    ".pt": 10737418240,  # 10 * 1024 * 1024 * 1024
    ".csv": 1073741824,  # 1 * 1024 * 1024 * 1024
    ".xlsx": 1073741824,  # 1 * 1024 * 1024 * 1024
    ".yaml": 1073741824,  # 1 * 1024 * 1024 * 1024
    ".ir": 1073741824  # 1 * 1024 * 1024 * 1024
}}
COMMOM_FILE_SIZE = 1048576  # 1 * 1024 * 1024


# Canonical dtype-name strings used by the mapping tables below.
INT8 = "Int8"
UINT8 = "UInt8"
INT16 = "Int16"
UINT16 = "UInt16"
INT32 = "Int32"
UINT32 = "UInt32"
INT64 = "Int64"
UINT64 = "UInt64"
FLOAT16 = "Float16"
FLOAT32 = "Float32"
FLOAT64 = "Float64"
BOOL = "Bool"
BFLOAT16 = "BFloat16"
INT4 = "Int4"

# dtype-name string -> mindspore dtype.
dtype_str_to_ms_dtype = {
    INT8: mstype.int8,
    UINT8: mstype.uint8,
    INT16: mstype.int16,
    UINT16: mstype.uint16,
    INT32: mstype.int32,
    UINT32: mstype.uint32,
    INT64: mstype.int64,
    UINT64: mstype.uint64,
    FLOAT16: mstype.float16,
    FLOAT32: mstype.float32,
    FLOAT64: mstype.float64,
    BOOL: mstype.bool_,
    BFLOAT16: mstype.bfloat16,
    INT4: mstype.qint4x2
}
ms_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_ms_dtype.items()}

# dtype-name string -> numpy dtype (numpy has no bfloat16 / int4 equivalents).
dtype_str_to_np_dtype = {
    INT8: np.int8,
    UINT8: np.uint8,
    INT16: np.int16,
    UINT16: np.uint16,
    INT32: np.int32,
    UINT32: np.uint32,
    INT64: np.int64,
    UINT64: np.uint64,
    FLOAT16: np.float16,
    FLOAT32: np.float32,
    FLOAT64: np.float64,
    BOOL: np.bool_
}
np_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_np_dtype.items()}

# dtype-name string -> torch dtype (torch lacks uint16/uint32/uint64).
dtype_str_to_torch_dtype = {
    INT8: torch.int8,
    UINT8: torch.uint8,
    INT16: torch.int16,
    INT32: torch.int32,
    INT64: torch.int64,
    FLOAT16: torch.float16,
    FLOAT32: torch.float32,
    FLOAT64: torch.float64,
    BOOL: torch.bool,
    BFLOAT16: torch.bfloat16,
}
torch_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_torch_dtype.items()}


# dtype-name string -> mindtorch dtype (mirrors the torch table).
dtype_str_to_mindtorch_dtype = {
    INT8: mindtorch.int8,
    UINT8: mindtorch.uint8,
    INT16: mindtorch.int16,
    INT32: mindtorch.int32,
    INT64: mindtorch.int64,
    FLOAT16: mindtorch.float16,
    FLOAT32: mindtorch.float32,
    FLOAT64: mindtorch.float64,
    BOOL: mindtorch.bool,
    BFLOAT16: mindtorch.bfloat16,
}
mindtorch_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_mindtorch_dtype.items()}

# Type-name strings as recorded in api_info json.
MINDSPORE_TENSOR_TYPE_STR = "mindspore.Tensor"
BOOL_TYPE_STR = "bool"
INT_TYPE_STR = "int"
FLOAT_TYPE_STR = "float"
SLICE_TYPE_STR = "slice"
TUPLE_TYPE_STR = "tuple"
STR_TYPE_STR = "str"
MINDSPORE_DTYPE_TYPE_STR = "mindspore.dtype"
TORCH_DTYPE_TYPE_STR = "torch.dtype"

# api_info type-name string -> Python/MindSpore type.
api_info_type_str_to_type = {
    MINDSPORE_TENSOR_TYPE_STR: mindspore.Tensor,
    BOOL_TYPE_STR: bool,
    INT_TYPE_STR: int,
    FLOAT_TYPE_STR: float,
    SLICE_TYPE_STR: slice,
    STR_TYPE_STR: str,
    MINDSPORE_DTYPE_TYPE_STR: typing.Type,
}
type_to_api_info_type_str = {value: key for key, value in api_info_type_str_to_type.items()}

# Numpy dtypes used when synthesizing input data.
# NOTE(review): the INT and UINT defaults are also np.float64 — presumably
# values are generated as floats and cast later; confirm against the
# construction code before changing.
DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE = np.float64
DEFAULT_CONSTRUCT_NP_INT_DTYPE = np.float64
DEFAULT_CONSTRUCT_NP_UINT_DTYPE = np.float64

# dtype-name groups used to classify dtypes by kind.
float_dtype_str_list = [
    FLOAT16,
    FLOAT32,
    FLOAT64,
    BFLOAT16,
]

int_dtype_str_list = [
    INT8,
    INT16,
    INT32,
    INT64,
    BOOL,
    INT4,
]

uint_dtype_str_list = [
    UINT8,
    UINT16,
    UINT32,
    UINT64,
]
505
+ ]
506
+
507
+ # ======= 比对类 =======
508
+
509
class CompareResult:
    """Value object holding the outcome of one comparison metric."""

    def __init__(self, compare_value, pass_status, err_msg):
        # Numeric metric value, or None when the metric was skipped.
        self.compare_value = compare_value
        # One of the CompareConst status strings ('pass' / 'error' / 'SKIP').
        self.pass_status = pass_status
        # Human-readable explanation for non-pass statuses.
        self.err_msg = err_msg
514
+
515
+
516
class BaseCompareAlgorithm(ABC):
    """Template base class for one comparison metric between two ComputeElements.

    Subclasses set ``compare_algorithm_name`` and implement
    ``check_validity``, ``run_compare`` and ``check_pass``; ``__call__``
    drives the workflow and wraps the outcome in a CompareResult.
    """

    def __init__(self) -> None:
        super().__init__()
        # Set by subclasses to one of the CompareConst column names.
        self.compare_algorithm_name = None
        # metric name -> pass status -> message recorded in the result.
        self.err_msg_mapping = {
            CompareConst.COSINE: {
                CompareConst.PASS: "",
                CompareConst.ERROR: f"cosine similarity is less than threshold: {CompareConst.COS_THRESHOLD} ",
                CompareConst.SKIP: "two inputs are not valid for computing cosine similarity, skip comparing ",
            },
            CompareConst.MAX_ABS_ERR: {
                CompareConst.PASS: "",
                CompareConst.ERROR: "max absolute difference is greater than " \
                                    f"threshold: {CompareConst.MAX_ABS_ERR_THRESHOLD} ",
                CompareConst.SKIP: "two inputs are not valid for computing max absolute difference, skip comparing ",
            },
            CompareConst.MAX_RELATIVE_ERR: {
                CompareConst.PASS: "",
                CompareConst.ERROR: "",
                CompareConst.SKIP: "",
            },
        }

    def __call__(self, bench_compute_element, tested_compute_element):
        '''
        Run this metric on a bench/tested pair and package the outcome.

        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            compare_result: CompareResult
        '''
        if self.check_validity(bench_compute_element, tested_compute_element):
            compare_value = self.run_compare(bench_compute_element, tested_compute_element)
            pass_status = self.check_pass(compare_value)
        else:
            # Inputs unsuitable for this metric (e.g. not tensors or shape
            # mismatch): record a SKIP instead of failing.
            logger.warning(f"not suitable for computing {self.compare_algorithm_name}, skip this.")
            compare_value = None
            pass_status = CompareConst.SKIP

        err_msg = self.err_msg_mapping.get(self.compare_algorithm_name).get(pass_status)

        compare_result = CompareResult(compare_value, pass_status, err_msg)
        return compare_result

    @staticmethod
    def convert_to_np_float64_ndarray(tensor):
        # Promote to float64 so both frameworks compare on a common dtype.
        if isinstance(tensor, mindspore.Tensor):
            ndarray = tensor.astype(mindspore.float64).numpy()
        elif isinstance(tensor, torch.Tensor):
            ndarray = tensor.to(torch.float64, copy=True).numpy()
        else:
            err_msg = "BaseCompareAlgorithm.convert_to_np_float64_ndarray failed: " \
                      "input is not mindspore.Tensor or torch.Tensor"
            # NOTE(review): error_log_with_exp only logs; on this branch
            # `ndarray` is never bound, so the return below raises
            # UnboundLocalError — presumably the helper was expected to
            # raise. Confirm intended behavior before relying on it.
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        return ndarray

    @staticmethod
    def check_two_tensor(bench_compute_element, tested_compute_element):
        # True iff both parameters are tensors and their shapes match.
        bench_parameter = bench_compute_element.get_parameter()
        tested_parameter = tested_compute_element.get_parameter()

        bench_is_tensor = isinstance(bench_parameter, (mindspore.Tensor, torch.Tensor))
        tested_is_tensor = isinstance(tested_parameter, (mindspore.Tensor, torch.Tensor))
        shape_same = bench_compute_element.get_shape() == tested_compute_element.get_shape()
        return bench_is_tensor and tested_is_tensor and shape_same

    @abstractmethod
    def check_validity(self, bench_compute_element, tested_compute_element):
        '''
        Decide whether this metric can be computed for the given pair.

        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            check_res: boolean
        '''
        raise NotImplementedError

    @abstractmethod
    def run_compare(self, bench_compute_element, tested_compute_element):
        '''
        Compute the metric value for the given pair.

        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            compare_value: float/int
        '''
        raise NotImplementedError

    @abstractmethod
    def check_pass(self, compare_value):
        '''
        Map a metric value onto a CompareConst pass status.

        Args:
            compare_value: float/int

        Return:
            pass_status: str
        '''
        raise NotImplementedError
617
+
618
+
619
class CosineSimilarityCompareAlgorithm(BaseCompareAlgorithm):
    """Compares two tensors via cosine similarity of their flattened values."""

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.COSINE

    def check_validity(self, bench_compute_element, tested_compute_element):
        # Valid only when both sides are tensors of identical shape.
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        # Promote both sides to float64 before measuring.
        ref_arr = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
        cmp_arr = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())

        ref_norm = np.linalg.norm(ref_arr)
        cmp_norm = np.linalg.norm(cmp_arr)
        inner = np.dot(ref_arr.flatten(), cmp_arr.flatten())
        # Epsilon keeps the ratio finite for all-zero inputs.
        return (MsCompareConst.EPSILON + inner) / (MsCompareConst.EPSILON + ref_norm * cmp_norm)

    def check_pass(self, compare_value):
        # Pass iff similarity strictly exceeds the threshold.
        return CompareConst.PASS if compare_value > CompareConst.COS_THRESHOLD else CompareConst.ERROR
642
+
643
+
644
class MaxAbsoluteDiffCompareAlgorithm(BaseCompareAlgorithm):
    """Compare benchmark/tested tensors by their largest element-wise absolute difference."""

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.MAX_ABS_ERR

    def check_validity(self, bench_compute_element, tested_compute_element):
        # applicable only to same-shape tensor pairs
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        bench_arr = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
        tested_arr = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())
        return np.abs(bench_arr - tested_arr).max()

    def check_pass(self, compare_value):
        # lower difference is better; strictly below threshold passes
        passed = compare_value < CompareConst.MAX_ABS_ERR_THRESHOLD
        return CompareConst.PASS if passed else CompareConst.ERROR
664
+
665
+
666
class MaxRelativeDiffCompareAlgorithm(BaseCompareAlgorithm):
    """Compare benchmark/tested tensors by their largest element-wise relative difference."""

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.MAX_RELATIVE_ERR

    def check_validity(self, bench_compute_element, tested_compute_element):
        # applicable only to same-shape tensor pairs
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        bench_arr = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
        tested_arr = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())
        # zero entries in the benchmark are padded with EPSILON to avoid division by zero
        safe_denominator = np.abs(bench_arr) + (bench_arr == 0) * MsCompareConst.EPSILON
        return np.max(np.abs(bench_arr - tested_arr) / safe_denominator)

    def check_pass(self, compare_value):
        # lower difference is better; strictly below threshold passes
        passed = compare_value < CompareConst.MAX_RELATIVE_ERR_THRESHOLD
        return CompareConst.PASS if passed else CompareConst.ERROR
688
+
689
+
690
# Registry of the available comparison metrics, keyed by metric name.
# Iteration order of this dict defines the column order of the detail CSV
# (see get_detail_csv_header / DataManager.to_detail_csv).
compare_algorithms = {
    CompareConst.COSINE: CosineSimilarityCompareAlgorithm(),
    CompareConst.MAX_ABS_ERR: MaxAbsoluteDiffCompareAlgorithm(),
    CompareConst.MAX_RELATIVE_ERR: MaxRelativeDiffCompareAlgorithm(),
}
695
+
696
+
697
+
698
class CompareStandard(Enum):
    """Accuracy-comparison standards used to judge an API; member values are opaque auto() tags."""
    BINARY_EQUALITY_STANDARD = auto()
    ABSOLUTE_THRESHOLD_STANDARD = auto()
    ULP_ERROR_STANDARD = auto()
    BENCHMARK_STANDARD = auto()
    THOUSANDTH_STANDARD = auto()
704
+
705
+
706
# NOTE: an exact duplicate definition of CompareStandard (identical members)
# previously appeared here; it was dead code that merely re-bound the same
# name and has been removed. The class remains defined once above.
712
+
713
+
714
# ======== File-operation utilities ==========
715
+
716
+ from collections import defaultdict
717
+ from functools import wraps
718
+
719
+
720
def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_type=None, accepted_value=None):
    '''
    Fetch and validate a value from a dict parsed out of an input json.

    Args:
        dict_instance: dict, dict parsed from input json
        key: str
        key_description: str, human-readable name used in error messages
        accepted_type: tuple, optional isinstance() filter
        accepted_value: Union[tuple, list], optional whitelist of values

    Return:
        value, the corresponding value of "key" in "dict_instance"

    Exception:
        raise ApiAccuracyCheckerException.ParseJsonFailed error when
        1. dict_instance is not a dict
        2. value is None
        3. value is not accepted type
        4. value is not accepted value
    '''
    if not isinstance(dict_instance, dict):
        raise ApiAccuracyCheckerException(
            ApiAccuracyCheckerException.ParseJsonFailed,
            "check_and_get_from_json_dict failed: input is not a dict")
    value = dict_instance.get(key)
    if value is None:
        raise ApiAccuracyCheckerException(
            ApiAccuracyCheckerException.ParseJsonFailed,
            f"check_and_get_from_json_dict failed: {key_description} is missing")
    if accepted_type is not None and not isinstance(value, accepted_type):
        raise ApiAccuracyCheckerException(
            ApiAccuracyCheckerException.ParseJsonFailed,
            f"check_and_get_from_json_dict failed: {key_description} is not accepted type: {accepted_type}")
    if accepted_value is not None and value not in accepted_value:
        raise ApiAccuracyCheckerException(
            ApiAccuracyCheckerException.ParseJsonFailed,
            f"check_and_get_from_json_dict failed: {key_description} is not accepted value: {accepted_value}")
    return value
753
+
754
+
755
def convert_to_tuple(args):
    """Return *args* as a tuple: sequences are converted, scalars are wrapped singly."""
    return tuple(args) if isinstance(args, (tuple, list)) else (args,)
761
+
762
+
763
def trim_output_compute_element_list(compute_element_list, forward_or_backward):
    '''
    Drop outputs that cannot be compared.

    Args:
        compute_element_list: List[ComputeElement]
        forward_or_backward: str, Union["forward", "backward"]

    Trim case: 1. parameter is None. 2. backward output has non float parameter.
    '''
    def _keep(element):
        if element.get_parameter() is None:
            return False
        if forward_or_backward == Const.BACKWARD and element.get_dtype() not in float_dtype_str_list:
            return False
        return True

    return [element for element in compute_element_list if _keep(element)]
777
+
778
+
779
+
780
+
781
# Tracks the live recursion depth of each decorated function (used by
# recursion_depth_decorator), keyed by the function object's id().
recursion_depth = defaultdict(int)
783
+
784
+
785
def recursion_depth_decorator(func_info, max_depth=Const.MAX_DEPTH):
    """Decorate a function so that recursion deeper than *max_depth* raises.

    Args:
        func_info: str, identifier of the decorated function, used in the error message.
        max_depth: int, maximum allowed simultaneous (recursive) call depth.

    Raises:
        RuntimeError: when the decorated function recurses beyond max_depth.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            func_id = id(func)
            recursion_depth[func_id] += 1
            try:
                # BUGFIX: the original tracked the depth but never enforced the
                # limit, so max_depth had no effect despite the docstring.
                if recursion_depth[func_id] > max_depth:
                    raise RuntimeError(f"call {func_info} exceeds the recursion limit {max_depth}")
                result = func(*args, **kwargs)
            finally:
                # always unwind the counter, even when func raises
                recursion_depth[func_id] -= 1
            return result

        return wrapper

    return decorator
802
+
803
+
804
+
805
class FileChecker:
    """
    The class for check file.

    Attributes:
        file_path: The file or dictionary path to be verified.
        path_type: file or dictionary
        ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to set file has writability or readability
        file_type(str): The correct file type for file
    """

    def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True):
        self.file_path = file_path
        self.path_type = self._check_path_type(path_type)
        self.ability = ability
        self.file_type = file_type
        self.is_script = is_script

    @staticmethod
    def _check_path_type(path_type):
        # only the two supported markers are legal
        if path_type in (FileCheckConst.DIR, FileCheckConst.FILE):
            return path_type
        logger.error(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.')
        raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR)

    def common_check(self):
        """Run the standard validation chain and return the resolved real path.

        Verifies: existence, soft link, path length, node type, requested
        read/write ability, owner consistency (script mode), character pattern,
        file size, suffix, and — for plain files — the parent directory.
        """
        check_path_exists(self.file_path)
        check_link(self.file_path)
        self.file_path = os.path.realpath(self.file_path)
        check_path_length(self.file_path)
        check_path_type(self.file_path, self.path_type)
        self.check_path_ability()
        if self.is_script:
            check_path_owner_consistent(self.file_path)
        check_path_pattern_valid(self.file_path)
        check_common_file_size(self.file_path)
        check_file_suffix(self.file_path, self.file_type)
        if self.path_type == FileCheckConst.FILE:
            check_dirpath_before_read(self.file_path)
        return self.file_path

    def check_path_ability(self):
        # READ_WRITE_ABLE intentionally triggers both individual checks
        if self.ability == FileCheckConst.WRITE_ABLE:
            check_path_writability(self.file_path)
        if self.ability == FileCheckConst.READ_ABLE:
            check_path_readability(self.file_path)
        if self.ability == FileCheckConst.READ_WRITE_ABLE:
            check_path_readability(self.file_path)
            check_path_writability(self.file_path)
858
+
859
+
860
class FileOpen:
    """
    The class for open file by a safe way.

    Attributes:
        file_path: The file or dictionary path to be opened.
        mode(str): The file open mode
    """
    SUPPORT_READ_MODE = ["r", "rb"]
    SUPPORT_WRITE_MODE = ["w", "wb", "a", "ab"]
    SUPPORT_READ_WRITE_MODE = ["r+", "rb+", "w+", "wb+", "a+", "ab+"]

    def __init__(self, file_path, mode, encoding='utf-8'):
        self.file_path = file_path
        self.mode = mode
        self.encoding = encoding
        self._handle = None

    def __enter__(self):
        self.check_file_path()
        # binary modes must not receive an encoding; text modes get an explicit one
        if "b" in self.mode:
            self._handle = open(self.file_path, self.mode)
        else:
            self._handle = open(self.file_path, self.mode, encoding=self.encoding)
        return self._handle

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._handle:
            self._handle.close()

    def check_file_path(self):
        """Validate mode and path safety before the file is actually opened."""
        support_mode = self.SUPPORT_READ_MODE + self.SUPPORT_WRITE_MODE + self.SUPPORT_READ_WRITE_MODE
        if self.mode not in support_mode:
            logger.error("File open not support %s mode" % self.mode)
            raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR)
        check_link(self.file_path)
        self.file_path = os.path.realpath(self.file_path)
        check_path_length(self.file_path)
        self.check_ability_and_owner()
        check_path_pattern_valid(self.file_path)
        # size and parent-directory checks only make sense for existing files
        if os.path.exists(self.file_path):
            check_common_file_size(self.file_path)
            check_dirpath_before_read(self.file_path)

    def check_ability_and_owner(self):
        """Check read/write permission and ownership according to the open mode."""
        path_exists = os.path.exists(self.file_path)
        if self.mode in self.SUPPORT_READ_MODE:
            check_path_exists(self.file_path)
            check_path_readability(self.file_path)
            check_path_owner_consistent(self.file_path)
        if self.mode in self.SUPPORT_WRITE_MODE and path_exists:
            check_path_writability(self.file_path)
            check_path_owner_consistent(self.file_path)
        if self.mode in self.SUPPORT_READ_WRITE_MODE and path_exists:
            check_path_readability(self.file_path)
            check_path_writability(self.file_path)
            check_path_owner_consistent(self.file_path)
917
+
918
+
919
def check_link(path):
    """Raise FileCheckException when *path* is a symbolic link."""
    if os.path.islink(os.path.abspath(path)):
        logger.error('The file path {} is a soft link.'.format(path))
        raise FileCheckException(FileCheckException.SOFT_LINK_ERROR)
924
+
925
+
926
def check_path_length(path, name_length=None):
    """Raise when the path or its basename exceeds the configured length limits."""
    max_name_length = name_length if name_length else FileCheckConst.FILE_NAME_LENGTH
    too_long = (len(path) > FileCheckConst.DIRECTORY_LENGTH
                or len(os.path.basename(path)) > max_name_length)
    if too_long:
        logger.error('The file path length exceeds limit.')
        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
932
+
933
+
934
def check_path_exists(path):
    """Raise FileCheckException when *path* does not exist."""
    if os.path.exists(path):
        return
    logger.error('The file path %s does not exist.' % path)
    raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
938
+
939
+
940
def check_path_readability(path):
    """Raise FileCheckException when the current process cannot read *path*."""
    readable = os.access(path, os.R_OK)
    if not readable:
        logger.error('The file path %s is not readable.' % path)
        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
944
+
945
+
946
def check_path_writability(path):
    """Raise FileCheckException when the current process cannot write *path*."""
    writable = os.access(path, os.W_OK)
    if not writable:
        logger.error('The file path %s is not writable.' % path)
        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
950
+
951
+
952
def check_path_executable(path):
    """Raise FileCheckException when the current process cannot execute *path*."""
    executable = os.access(path, os.X_OK)
    if not executable:
        logger.error('The file path %s is not executable.' % path)
        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
956
+
957
+
958
def check_other_user_writable(path):
    """Raise when the 'others' write bit is set on *path* (world-writable is insecure)."""
    mode = os.stat(path).st_mode
    if mode & 0o002:
        logger.error('The file path %s may be insecure because other users have write permissions. ' % path)
        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
963
+
964
+
965
def check_path_owner_consistent(path):
    """Raise unless *path* is owned by the current user (root is always allowed).

    Fixes the grammar of the error message ("because is does not belong" ->
    "because it does not belong") and calls os.getuid() only once.
    """
    file_owner = os.stat(path).st_uid
    current_uid = os.getuid()
    # root (uid 0) may inspect any file; everyone else must own the path
    if file_owner != current_uid and current_uid != 0:
        logger.error('The file path %s may be insecure because it does not belong to you.' % path)
        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
970
+
971
+
972
def check_path_pattern_valid(path):
    """Raise when *path* contains characters outside the allowed pattern."""
    if re.match(FileCheckConst.FILE_VALID_PATTERN, path):
        return
    logger.error('The file path %s contains special characters.' % (path))
    raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
976
+
977
+
978
def check_file_size(file_path, max_size):
    """Raise when *file_path* cannot be stat'ed or its size is >= *max_size* bytes."""
    try:
        file_size = os.path.getsize(file_path)
    except OSError as os_error:
        logger.error(f'Failed to open "{file_path}". {str(os_error)}')
        raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) from os_error
    if file_size < max_size:
        return
    logger.error(f'The size ({file_size}) of {file_path} exceeds ({max_size}) bytes, tools not support.')
    raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR)
987
+
988
+
989
def check_common_file_size(file_path):
    """Enforce per-suffix size limits; unknown suffixes fall back to the common limit."""
    if not os.path.isfile(file_path):
        return
    for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items():
        if file_path.endswith(suffix):
            check_file_size(file_path, max_size)
            return
    check_file_size(file_path, FileCheckConst.COMMOM_FILE_SIZE)
996
+
997
+
998
def check_file_suffix(file_path, file_suffix):
    """Raise when a required *file_suffix* is given and *file_path* does not end with it."""
    if not file_suffix:
        return
    if not file_path.endswith(file_suffix):
        logger.error(f"The {file_path} should be a {file_suffix} file!")
        raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
1003
+
1004
+
1005
def check_path_type(file_path, file_type):
    """Raise unless *file_path* matches the expected node type (file or directory)."""
    if file_type == FileCheckConst.FILE and not os.path.isfile(file_path):
        logger.error(f"The {file_path} should be a file!")
        raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
    if file_type == FileCheckConst.DIR and not os.path.isdir(file_path):
        logger.error(f"The {file_path} should be a dictionary!")
        raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
1014
+
1015
def make_dir(dir_path):
    """Create *dir_path* with the tool's directory permissions, then validate it."""
    check_path_before_create(dir_path)
    dir_path = os.path.realpath(dir_path)
    if os.path.isdir(dir_path):
        # nothing to do: directory already present
        return
    try:
        os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True)
    except OSError as ex:
        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR,
                                 f"Failed to create {dir_path}. "
                                 f"Please check the path permission or disk space. {str(ex)}") from ex
    # post-creation sanity check of the new directory
    FileChecker(dir_path, FileCheckConst.DIR).common_check()
1028
+
1029
+
1030
+
1031
+
1032
@recursion_depth_decorator('msprobe.core.common.file_utils.create_directory', max_depth=16)
def create_directory(dir_path):
    """
    Function Description:
        creating a safe directory with specified permissions
    Parameter:
        dir_path: directory path
    Exception Description:
        when invalid data throw exception
    """
    check_link(dir_path)
    check_path_before_create(dir_path)
    dir_path = os.path.realpath(dir_path)
    parent_dir = os.path.dirname(dir_path)
    # recursively materialize missing ancestors first; the decorator caps the
    # recursion depth at 16 levels of missing parents
    if not os.path.isdir(parent_dir):
        create_directory(parent_dir)
    make_dir(dir_path)
1049
+
1050
+
1051
def check_path_before_create(path):
    """Validate *path* (link-free, length-bounded, charset-restricted) before creating it."""
    check_link(path)
    if path_len_exceeds_limit(path):
        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, 'The file path length exceeds limit.')
    real_path = os.path.realpath(path)
    if not re.match(FileCheckConst.FILE_PATTERN, real_path):
        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR,
                                 'The file path {} contains special characters.'.format(path))
1059
+
1060
+
1061
def check_dirpath_before_read(path):
    """Warn when the parent directory of *path* looks insecure before reading it.

    NOTE(review): the original body computed ``dirpath`` and then performed no
    check at all — almost certainly missing logic. This version warns (without
    raising, so existing callers keep working) when the directory is writable
    by other users or not owned by the current user; confirm against the
    upstream implementation.
    """
    path = os.path.realpath(path)
    dirpath = os.path.dirname(path)
    try:
        check_other_user_writable(dirpath)
        check_path_owner_consistent(dirpath)
    except FileCheckException:
        logger.warning(f"The directory {dirpath} of the file to read may be insecure.")
1064
+
1065
+
1066
def check_file_or_directory_path(path, isdir=False):
    """
    Function Description:
        check whether the path is valid
    Parameter:
        path: the path to check
        isdir: the path is dir or file
    Exception Description:
        when invalid data throw exception
    """
    # directories must be writable (outputs land there); files must be readable
    if isdir:
        checker = FileChecker(path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE)
    else:
        checker = FileChecker(path, FileCheckConst.FILE, FileCheckConst.READ_ABLE)
    checker.common_check()
1081
+
1082
+
1083
def change_mode(path, mode):
    """Best-effort chmod: skip missing paths and symlinks; wrap PermissionError."""
    if os.path.islink(path) or not os.path.exists(path):
        return
    try:
        os.chmod(path, mode)
    except PermissionError as ex:
        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR,
                                 'Failed to change {} authority. {}'.format(path, str(ex))) from ex
1091
+
1092
+
1093
def path_len_exceeds_limit(file_path):
    """Return True when the resolved path or its basename exceeds the length limits."""
    if len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH:
        return True
    return len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH
1096
+
1097
def load_npy(filepath):
    """Safely load a .npy file (pickle disabled) after validating the path."""
    check_file_or_directory_path(filepath)
    try:
        loaded = np.load(filepath, allow_pickle=False)
    except Exception as e:
        logger.error(f"The numpy file failed to load. Please check the path: {filepath}.")
        raise RuntimeError(f"Load numpy file {filepath} failed.") from e
    return loaded
1105
+
1106
def write_csv(data, filepath, mode="a+", malicious_check=False):
    """Append *data* (iterable of rows) to the CSV file at *filepath*.

    Args:
        data: iterable of rows (each row an iterable of cells).
        filepath: destination CSV path.
        mode: open mode passed to FileOpen (default append).
        malicious_check: when True, reject string cells that look like
            spreadsheet formula injections.

    Raises:
        RuntimeError: on malicious cell content or any write failure.

    Fixes: removed a leftover debug ``print(f"file_path:...")`` that polluted
    stdout on every write, and uses the resolved real path consistently.
    """
    def csv_value_is_valid(value: str) -> bool:
        if not isinstance(value, str):
            return True
        try:
            # -1.00 or +1.00 should be considered as digit numbers
            float(value)
        except ValueError:
            # otherwise, they will be considered as formular injections
            return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
        return True

    if malicious_check:
        for row in data:
            for cell in row:
                if not csv_value_is_valid(cell):
                    raise RuntimeError(f"Malicious value [{cell}] is not allowed "
                                       f"to be written into the csv: {filepath}.")

    check_path_before_create(filepath)
    file_path = os.path.realpath(filepath)
    try:
        with FileOpen(file_path, mode, encoding='utf-8-sig') as f:
            writer = csv.writer(f)
            writer.writerows(data)
    except Exception as e:
        logger.error(f'Save csv file "{os.path.basename(file_path)}" failed')
        raise RuntimeError(f"Save csv file {file_path} failed.") from e
    change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
1136
+
1137
+
1138
+
1139
def write_csv_header(csv_path, header_func):
    """Write the CSV header row produced by *header_func* (used on first write)."""
    header = header_func()  # obtain the header row
    logger.debug(f"Writing CSV header: {header}")
    write_csv([header], csv_path, mode="a+")
1144
+
1145
+
1146
def get_result_csv_header():
    """Header row of the summary (result) CSV: api name, forward/backward status, message."""
    return [
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
        MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    ]
1154
+
1155
+
1156
def get_detail_csv_header():
    """Header row of the detail CSV: basic info, one column per compare algorithm, status."""
    header = [
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
        MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
        MsCompareConst.DETAIL_CSV_SHAPE,
    ]
    # one column per registered metric, in registry order
    header.extend(compare_algorithms.keys())
    header.append(MsCompareConst.DETAIL_CSV_PASS_STATUS)
    header.append(MsCompareConst.DETAIL_CSV_MESSAGE)
    return header
1170
+
1171
+
1172
def check_csv_header(headers, required_constants, csv_path):
    """Validate that *headers* contains every required column name.

    Substring matching is used because the header row read back from disk may
    carry a byte-order mark or similar prefix.

    Args:
        headers: list[str], header row read from the CSV file.
        required_constants: iterable[str], column names that must be present.
        csv_path: str, path used in the error message.

    Returns:
        True when all required columns are present. (The original implicitly
        returned None, so callers testing ``if check_csv_header(...)`` — e.g.
        DataManager.initialize_api_names_set — silently skipped their body.)

    Raises:
        MsprobeBaseException: when any required column is missing.
    """
    missing_constants = [const for const in required_constants if not any(const in header for header in headers)]

    if missing_constants:
        raise MsprobeBaseException(
            MsprobeBaseException.MISSING_HEADER_ERROR,
            f"{csv_path} 缺少以下必需的表头字段: {missing_constants}"
        )
    return True
1181
def add_time_as_suffix(name):
    """Append a YYYYmmddHHMMSS local timestamp and the .csv extension to *name*."""
    timestamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
    return '{}_{}.csv'.format(name, timestamp)
1183
+
1184
+
1185
# ======= Result persistence management ========
1186
+
1187
class DataManager:
    """Accumulates per-API compare results and flushes them to two CSV files.

    ``detail_out_path`` receives one row per compared output; ``result_out_path``
    receives one summary row per API. Passing an existing ``result_csv_path``
    enables resuming a previous run: already-checked API names are collected so
    ``is_unique_api`` can skip them.
    """

    def __init__(self, csv_dir, result_csv_path):
        # (api_name, forward/backward) -> list of (basic_info, compare_result_dict)
        self.results = {}
        self.results_exception_skip = {}
        self.is_first_write = True  # marker: headers still need to be written
        self.csv_dir = csv_dir
        self.api_names_set = set()  # set of API names already seen
        # if result_csv_path was given, enable resume-from-checkpoint
        if result_csv_path:
            self.resume_from_last_csv(result_csv_path)
            self.initialize_api_names_set(result_csv_path)
        else:
            # by default derive fresh, timestamped output paths; files are
            # created lazily on the first write
            self.result_out_path = os.path.join(self.csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
            self.detail_out_path = os.path.join(
                self.csv_dir,
                os.path.basename(self.result_out_path).replace("result", "details")
            )

        if self.detail_out_path and os.path.exists(self.detail_out_path):
            check_file_or_directory_path(self.detail_out_path)

        if self.result_out_path and os.path.exists(self.result_out_path):
            check_file_or_directory_path(self.result_out_path)

    def initialize_api_names_set(self, result_csv_path):
        """Read the existing result CSV and collect already-processed API names."""
        # read back the previous run's summary rows
        csv_data = read_csv(result_csv_path, as_pd=False)

        # header row; empty list when the file itself is empty
        headers = csv_data[0] if csv_data else []

        # NOTE(review): this branch only runs when check_csv_header returns a
        # truthy value — confirm its return contract.
        if check_csv_header(headers, get_result_csv_header(), result_csv_path):

            # locate the "API Name" column
            api_name_index = None
            for i, header in enumerate(headers):
                if MsCompareConst.DETAIL_CSV_API_NAME in header:  # the header may carry a BOM, so match by containment
                    api_name_index = i
                    break

            if api_name_index is None:
                logger.warning(f"{result_csv_path} No column contains 'API Name'.")
                return

            # collect the API name of every data row
            for row in csv_data[1:]:  # skip the header row
                if row and len(row) > api_name_index:
                    api_name = row[api_name_index]
                    if api_name:
                        self.api_names_set.add(api_name)

        logger.debug(f"Initialized API names set from existing CSV: {self.api_names_set}")

    def is_unique_api(self, api_name):
        """Return True (and remember the name) if unseen; False if already recorded."""
        if api_name in self.api_names_set:
            return False
        self.api_names_set.add(api_name)
        return True

    def resume_from_last_csv(self, result_csv_path):
        """Resume from the result CSV of a previous run (checkpoint restart)."""
        # directory of the previous run
        last_dir = os.path.dirname(result_csv_path)

        # reuse the previous directory and output paths for subsequent writes
        self.csv_dir = last_dir
        self.detail_out_path = os.path.join(last_dir, os.path.basename(result_csv_path).replace("result", "details"))
        if self.detail_out_path and os.path.exists(self.detail_out_path):
            check_file_or_directory_path(self.detail_out_path)
        self.result_out_path = result_csv_path
        self.is_first_write = False

    def save_results(self, api_name_str):
        """Flush detail and summary CSVs for *api_name_str*, then clear the buffers."""
        if self.is_first_write:
            # write both headers exactly once
            logger.info("Writing CSV headers for the first time.")
            write_csv_header(self.detail_out_path, get_detail_csv_header)
            write_csv_header(self.result_out_path, get_result_csv_header)
            self.is_first_write = False  # mark as written so headers are not duplicated

        """写入详细输出和结果摘要并清理结果"""
        logger.debug("Starting to write detailed output to CSV.")
        self.to_detail_csv(self.detail_out_path)
        logger.debug(f"Detailed output for {api_name_str} written to {self.detail_out_path}.")

        logger.debug("Starting to write result summary to CSV.")
        self.to_result_csv(self.result_out_path)
        logger.debug(f"Result summary for {api_name_str} written to {self.result_out_path}.")

        # clear the buffers, ready for the next call
        self.clear_results()

    def record(self, output_list):
        # each output is (api_real_name, forward_or_backward, basic_info, compare_result_dict)
        if output_list is None:
            return
        for output in output_list:
            api_real_name, forward_or_backward, basic_info, compare_result_dict = output
            key = (api_real_name, forward_or_backward)
            if key not in self.results:
                self.results[key] = []
            self.results[key].append((basic_info, compare_result_dict))
        logger.debug(f"Complete self.results after recording: {self.results}")

    def record_exception_skip(self, api_name, forward_or_backward, err_msg):
        '''
        record exception_skip information into self.results_exception_skip.
        self.results_exception_skip: dict{str: dict{"forward": str/None, "backward": str/None}}
        string in key is api_name, string in value is err_msg
        '''
        if api_name not in self.results_exception_skip:
            self.results_exception_skip[api_name] = {Const.FORWARD: None, Const.BACKWARD: None}
        self.results_exception_skip[api_name][forward_or_backward] = err_msg

    def clear_results(self):
        """Empty both result buffers."""
        logger.debug("Clearing self.results data.")
        self.results.clear()
        self.results_exception_skip.clear()

    def to_detail_csv(self, csv_path):
        """Append one row per compared output to the detail CSV."""
        logger.debug("Preparing detail CSV headers and rows.")
        detail_csv = []

        # metric columns in registry order — must match get_detail_csv_header
        detail_csv_header_compare_result = list(compare_algorithms.keys())

        for _, results in self.results.items():
            for res in results:
                basic_info, compare_result_dict = res
                csv_row_basic_info = [
                    basic_info.api_name,
                    basic_info.bench_dtype,
                    basic_info.tested_dtype,
                    basic_info.shape
                ]
                csv_row_compare_result = [
                    compare_result_dict.get(algorithm_name).compare_value
                    for algorithm_name in detail_csv_header_compare_result
                ]
                csv_row_status = [basic_info.status, basic_info.err_msg]
                csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
                detail_csv.append(csv_row)
                logger.debug(f"Detail CSV row added: {csv_row}")

        logger.debug(f"Writing detail CSV to {csv_path}.")
        write_csv(detail_csv, csv_path, mode="a+")
        logger.debug(f"Detail CSV written successfully to {csv_path}.")

    def to_result_csv(self, csv_path):
        '''
        depend on both self.results and self.results_exception_skip
        '''
        logger.debug("Preparing result CSV data.")
        result_csv = []

        # fold per-output rows into one forward/backward summary per API
        result_csv_dict = {}
        for key, results in self.results.items():
            api_real_name, forward_or_backward = key
            pass_status = CompareConst.PASS
            overall_err_msg = ""

            for res in results:
                basic_info, _ = res
                if basic_info.status != CompareConst.PASS:
                    pass_status = CompareConst.ERROR
                overall_err_msg += basic_info.err_msg

            # messages are only kept for failing APIs
            overall_err_msg = "" if pass_status == CompareConst.PASS else overall_err_msg

            if api_real_name not in result_csv_dict:
                result_csv_dict[api_real_name] = ResultCsvEntry()
            if forward_or_backward == Const.FORWARD:
                result_csv_dict[api_real_name].forward_pass_status = pass_status
                result_csv_dict[api_real_name].forward_err_msg = overall_err_msg
            else:
                result_csv_dict[api_real_name].backward_pass_status = pass_status
                result_csv_dict[api_real_name].backward_err_msg = overall_err_msg

        for api_name, entry in result_csv_dict.items():
            overall_err_msg = "" if (entry.forward_pass_status == CompareConst.PASS and
                                     entry.backward_pass_status == CompareConst.PASS) else \
                entry.forward_err_msg + entry.backward_err_msg
            row = [
                api_name,
                entry.forward_pass_status,
                entry.backward_pass_status,
                overall_err_msg
            ]
            # change row if this api has exception_skip information
            if api_name in self.results_exception_skip:
                if self.results_exception_skip[api_name][Const.FORWARD] is not None:
                    row[1] = CompareConst.SKIP
                    row[-1] += self.results_exception_skip[api_name][Const.FORWARD]
                if self.results_exception_skip[api_name][Const.BACKWARD] is not None:
                    row[2] = CompareConst.SKIP
                    row[-1] += self.results_exception_skip[api_name][Const.BACKWARD]
                del self.results_exception_skip[api_name]
            result_csv.append(row)
            logger.debug(f"Result CSV row added: {row}")
        # APIs that only produced exception skips (no compare results) still get a row
        for api_name in self.results_exception_skip:
            current_exception_skip = self.results_exception_skip[api_name]
            forward_status = None
            backward_status = None
            err_msg = ""
            if current_exception_skip[Const.FORWARD] is not None:
                forward_status = CompareConst.SKIP
                err_msg += current_exception_skip[Const.FORWARD]
            if current_exception_skip[Const.BACKWARD] is not None:
                backward_status = CompareConst.SKIP
                err_msg += current_exception_skip[Const.BACKWARD]
            row = [api_name, forward_status, backward_status, err_msg]
            result_csv.append(row)

        write_csv(result_csv, csv_path, mode="a+")
        logger.debug(f"Result CSV written successfully to {csv_path}.")

        # mark headers as written to avoid duplicating them later
        self.is_first_write = False
1408
+
1409
# ======== Global context =======
1410
+
1411
class GlobalContext:
    """Process-wide run configuration shared across the accuracy checker."""

    def __init__(self):
        # defaults: constructed-input mode, no dump dir, MindSpore framework
        self.is_constructed = True
        self.dump_data_dir = ""
        self.framework = Const.MS_FRAMEWORK

    def init(self, is_constructed, dump_data_dir, framework):
        """Overwrite the context with values derived from user configuration."""
        self.is_constructed = is_constructed
        self.dump_data_dir = dump_data_dir
        self.framework = framework

    def get_dump_data_dir(self):
        # directory holding dumped data ("" by default)
        return self.dump_data_dir

    def get_is_constructed(self):
        # presumably True when api inputs are constructed rather than loaded
        # from a dump — confirm against callers
        return self.is_constructed

    def get_framework(self):
        # one of the Const.*_FRAMEWORK identifiers
        return self.framework


# module-level singleton consumed by the rest of the checker
global_context = GlobalContext()
1433
+
1434
# ======== Input type classes =======
1435
+
1436
def seed_all(seed=1234):
    """Seed every supported RNG and enable deterministic execution.

    Args:
        seed: int, seed applied to python `random`, PYTHONHASHSEED, numpy,
            torch, mindtorch and mindspore.

    BUGFIX: the original default ``seed={random_seed}`` was an unfilled
    template placeholder — a set literal over an undefined name — which raised
    NameError at import time (and a set is not a valid seed anyway). The
    integer default 1234 restores a usable, reproducible default.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(True)
    mindtorch.manual_seed(seed)
    mindtorch.use_deterministic_algorithms(True)
    mindspore.set_deterministic(True)
1445
+
1446
class ApiInputAggregation:
    """Bundle of all inputs needed to replay one API call."""

    def __init__(self, inputs, kwargs, gradient_inputs) -> None:
        """
        Args:
            inputs: List[ComputeElement]
            kwargs: dict{str: ComputeElement}
            gradient_inputs: Union[List[ComputeElement], None]
        """
        self.inputs = inputs
        self.kwargs = kwargs
        self.gradient_inputs = gradient_inputs
1457
+
1458
+
1459
# Maps (api_type, framework) -> the module object that owns the API.  Used by
# ApiRunner.get_api_instance to resolve a name like "relu" to e.g.
# mindspore.mint.relu or torch.relu.  Entry order is preserved as-is.
api_parent_module_mapping = {
    (MsCompareConst.MINT, Const.MS_FRAMEWORK): mindspore.mint,
    (MsCompareConst.MINT, Const.PT_FRAMEWORK): torch,
    (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
    (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional,
    (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): mindspore.Tensor,
    (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor,
    (MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): mindtorch_tensor,
    (MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): torch.Tensor,
    (MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): mindtorch,
    (MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): torch,
    (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): mindtorch_func,
    (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): torch.nn.functional,
    (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): mindtorch_dist,
    (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed,
    (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): mindspore.ops
}


# Same keys as api_parent_module_mapping, mapping to the module's dotted name
# as a string; only used to build readable error messages (full API path).
api_parent_module_str_mapping = {
    (MsCompareConst.MINT, Const.MS_FRAMEWORK): "mindspore.mint",
    (MsCompareConst.MINT, Const.PT_FRAMEWORK): "torch",
    (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): "mindspore.mint.nn.functional",
    (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): "torch.nn.functional",
    (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): "mindspore.Tensor",
    (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor",
    (MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): "mindtorch_tensor",
    (MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): "torch.Tensor",
    (MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): "mindtorch",
    (MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): "torch",
    (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): "mindtorch_func",
    (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): "torch.nn.functional",
    (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): "mindtorch_dist",
    (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed",
    (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): "mindspore.ops"
}
1496
+
1497
+
1498
class ApiRunner:
    """Resolves an API name to the concrete callable of a framework and runs
    it forward or backward, returning outputs wrapped as ComputeElement."""

    def __call__(self, api_input_aggregation, api_name_str, forward_or_backward=Const.FORWARD,
                 api_platform=Const.MS_FRAMEWORK):
        '''
        Args:
            api_input_aggregation: ApiInputAggregation
            api_name_str: str, e.g. "MintFunctional.relu.0"
            forward_or_backward: str, Union["forward", "backward"]
            api_platform: str, Union["mindspore", "torch", "mindtorch"]

        Return:
            outputs: list[ComputeElement] (mindspore/mindtorch runs return
            extra elements — see run_api for the exact shapes)

        Description:
            run mindspore.mint/torch api
        '''
        api_type_str, api_sub_name = self.get_info_from_name(api_name_str, api_platform)
        api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform)

        return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform)

    @staticmethod
    def get_info_from_name(api_name_str, api_platform=Const.MS_FRAMEWORK):
        """
        Args:
            api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"
            api_platform: str, the platform for the API, which can be either "mindspore" or "mindtorch".
                It specifies which framework is being used. Default is "mindspore".
        Return:
            api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Torch", "Functional"]
            api_sub_name: str, e.g. "relu"
        """
        # Expected format: "<ApiType>.<api_name>.<call_index>".
        api_name_list = api_name_str.split(Const.SEP)
        if len(api_name_list) != 3:
            err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
        api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
        # Validate the API type against the allow-list of the target platform.
        if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API,
                                MsCompareConst.FUNCTIONAL_API] \
                and api_platform == Const.MS_FRAMEWORK:
            err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))

        if api_type_str not in MsCompareConst.MT_VALID_API_TYPES and api_platform == Const.MT_FRAMEWORK:
            err_msg = f"ApiRunner.get_info_from_name failed: not torch, functional or Tensor api"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
        return api_type_str, api_sub_name

    @staticmethod
    def get_api_instance(api_type_str, api_sub_name, api_platform):
        """
        Args:
            api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Functional"]
            api_sub_name: str, e.g. "relu"
            api_platform: str: Union["mindspore", "pytorch"]

        Return:
            api_instance: function object

        Description:
            get mindspore.mint/torch api function
            mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
            mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
        """
        # Resolve the owning module and its printable name from the two
        # parallel lookup tables defined above.
        api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
        api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
        full_api_name = api_parent_module_str + Const.SEP + api_sub_name

        if not hasattr(api_parent_module, api_sub_name):
            err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))

        api_instance = getattr(api_parent_module, api_sub_name)
        if not callable(api_instance):
            err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not callable"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))

        return api_instance

    @staticmethod
    def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform):
        # Materialize every ComputeElement into a concrete value/tensor on the
        # requested platform.
        inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                       for compute_element in api_input_aggregation.inputs)
        kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform)
                  for key, value in api_input_aggregation.kwargs.items()}
        gradient_inputs = api_input_aggregation.gradient_inputs

        if forward_or_backward == Const.FORWARD:
            forward_result = api_instance(*inputs, **kwargs)  # can be single tensor or tuple
            forward_result_tuple = convert_to_tuple(forward_result)
            res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple]
            # NOTE: return shape differs by platform — mindspore/mindtorch
            # forward returns a 4-tuple; the torch (bench) forward falls
            # through to the plain list at the end of this function.
            if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK:
                return res_compute_element_list, inputs, kwargs, forward_result_tuple
        else:
            if gradient_inputs is None:
                err_msg = f"ApiRunner.run_api failed: run backward api but gradient_inputs is missing"
                logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
            gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                                    for compute_element in gradient_inputs)
            if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK:
                # GradOperation with sens_param=True takes the upstream
                # gradient as the last positional argument (unwrapped when
                # there is exactly one).
                if len(gradient_inputs) == 1:
                    gradient_inputs = gradient_inputs[0]

                def api_with_kwargs(*forward_inputs):
                    return api_instance(*forward_inputs, **kwargs)

                grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs)
                backward_result = grad_func(*inputs, gradient_inputs)  # can be single tensor or tuple
                backward_result_tuple = convert_to_tuple(backward_result)
                res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple]
                return res_compute_element_list, gradient_inputs, backward_result_tuple
            else:
                # set requires_grad on floating-point torch inputs so autograd
                # tracks them; remember which indices to read .grad from.
                requires_grad_index = []
                for index, tensor in enumerate(inputs):
                    if isinstance(tensor, torch.Tensor) and \
                            torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list:
                        setattr(tensor, "requires_grad", True)
                        requires_grad_index.append(index)
                forward_results = api_instance(*inputs, **kwargs)
                forward_results = convert_to_tuple(forward_results)
                for forward_res, gradient_in in zip(forward_results, gradient_inputs):
                    forward_res.backward(gradient_in)
                backward_result_list = []
                for index in requires_grad_index:
                    backward_result_list.append(getattr(inputs[index], "grad"))
                res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_list]

        return res_compute_element_list
1629
+
1630
# Module-level runner instance used by run_and_compare_helper.
api_runner = ApiRunner()
1632
+
1633
+ # ======== Data structure classes ========
1634
+
1635
class ResultCsvEntry:
    """Per-API row state accumulated before writing the result CSV."""

    def __init__(self) -> None:
        # Pass/fail status strings; stay None until that direction has run.
        self.forward_pass_status = None
        self.backward_pass_status = None
        # Per-direction error details; combined into overall_err_msg later.
        self.forward_err_msg = ""
        self.backward_err_msg = ""
        self.overall_err_msg = None
1642
+
1643
class ProcessResultPacket:
    """Outcome of processing one API: status flag, payload and error text."""

    def __init__(self, process_status, result, err_msg) -> None:
        self.process_status = process_status
        self.result = result
        self.err_msg = err_msg
1648
+
1649
class MstensorMetaData:
    """Lazy tensor description: dtype, value range, shape and .npy location.

    Stored instead of a real tensor so data is only materialized (constructed
    or loaded from disk) when ComputeElement.get_parameter is called.
    """

    def __init__(self, dtype_str, npy_path, maximum, minimum, shape) -> None:
        self.dtype_str = dtype_str
        self.npy_path = npy_path
        self.maximum = maximum
        self.minimum = minimum
        self.shape = shape
1656
+
1657
+
1658
class DtypeMetaData:
    """Records a dtype by name; resolved to a platform dtype on demand."""

    def __init__(self, dtype_str) -> None:
        self.dtype_str = dtype_str
1661
+
1662
+
1663
class ComputeElement:
    """Wraps one API argument or output: a concrete value, a tensor, or lazy
    tensor metadata (MstensorMetaData) materialized on demand.

    Construction paths:
      * parameter=...            -> wrap an in-memory value (tensor/scalar/tuple/dtype)
      * compute_element_info=... -> rebuild from an api_info.json entry
      * neither                  -> a "None" placeholder element
    """

    def __init__(self, compute_element_info=None, parameter=None):
        self.supported_parameter_type = tuple(type_to_api_info_type_str.keys()) + tuple([torch.Tensor, tuple])
        if parameter is not None:
            self._init_with_parameter(parameter)
        elif isinstance(compute_element_info, (list, dict)):
            self._init_from_compute_element_info(compute_element_info)
        elif compute_element_info is None:
            self._init_from_null_compute_element_info()
        else:
            # Fix: previously a dead `pass` preceded the error call in a way
            # that left the error path ambiguous; raise only for unsupported
            # constructor arguments.
            logger.error_log_with_exp(
                "ComputeElement.__init__ failed: not init with parameter or compute_element info is not (list, dict)",
                ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))

    @staticmethod
    def transfer_to_torch_tensor(ms_tensor):
        '''
        Args:
            ms_tensor: mindspore.Tensor
        Return:
            torch_tensor: torch.Tensor
        '''
        ms_dtype = ms_tensor.dtype
        dtype_str = ms_dtype_to_dtype_str.get(ms_dtype)
        if dtype_str not in dtype_str_to_torch_dtype:
            err_msg = f"ComputeElement.transfer_to_torch_tensor failed: no matching torch dtype for {dtype_str}"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        else:
            torch_dtype = dtype_str_to_torch_dtype.get(dtype_str)

        # Widen to a 64-bit intermediate before the numpy round-trip, then
        # cast back to the matching torch dtype.
        if dtype_str in int_dtype_str_list:
            middle_dtype = mindspore.int64
        else:
            middle_dtype = mindspore.float64
        np_ndarray = ms_tensor.astype(middle_dtype).numpy()
        torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype)
        return torch_tensor

    @staticmethod
    def transfer_to_mindtorch_tensor(ms_tensor):
        """
        Args:
            ms_tensor: mindspore.Tensor
        Return:
            mindtorch_tensor: mindtorch.Tensor
        """
        ms_dtype = ms_tensor.dtype
        dtype_str = ms_dtype_to_dtype_str.get(ms_dtype)
        if dtype_str not in dtype_str_to_mindtorch_dtype:
            err_msg = f"ComputeElement.transfer_to_mindtorch_tensor failed: no matching mindtorch dtype for {dtype_str}"
            logger.error_log_with_exp(err_msg,
                                      ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        else:
            mindtorch_dtype = dtype_str_to_mindtorch_dtype.get(dtype_str)

        if dtype_str in int_dtype_str_list:
            middle_dtype = mindspore.int64
        else:
            middle_dtype = mindspore.float64

        np_ndarray = ms_tensor.astype(middle_dtype).numpy()

        # Fix: cast to the resolved mindtorch dtype. The original called
        # `.to(ms_dtype)` and never used `mindtorch_dtype`, which the lookup
        # above clearly intended (mirrors transfer_to_torch_tensor).
        mindtorch_tensor = mindtorch.from_numpy(np_ndarray).to(mindtorch_dtype)

        return mindtorch_tensor

    @staticmethod
    def transfer_to_mindspore_tensor(torch_tensor):
        '''
        Args:
            torch_tensor: torch.Tensor

        Return:
            ms_tensor: mindspore.Tensor
        '''
        torch_dtype = torch_tensor.dtype
        dtype_str = torch_dtype_to_dtype_str.get(torch_dtype)
        if dtype_str not in dtype_str_to_ms_dtype:
            err_msg = \
                f"ComputeElement._transfer_to_mindspore_tensor failed: no matching mindspore dtype for {dtype_str}"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        else:
            ms_dtype = dtype_str_to_ms_dtype.get(dtype_str)

        if dtype_str in int_dtype_str_list:
            middle_dtype = torch.int64
        else:
            middle_dtype = torch.float64
        np_ndarray = torch_tensor.to(middle_dtype, copy=True).numpy()
        ms_tensor = mindspore.Tensor.from_numpy(np_ndarray).astype(ms_dtype)
        return ms_tensor

    @staticmethod
    def convert_inf_to_real_num(value, dtype_str):
        """Clamp +/-inf to the finite extremes of the target float dtype.

        Fix/generalization: `dtype_str` may be either a dtype-name string or
        an already-resolved numpy dtype — `_construct_ndarray` passes the
        latter, which previously always fell through to the default dtype.
        """
        if isinstance(dtype_str, str):
            np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
        else:
            np_dtype = dtype_str
        if value == float("inf"):
            value = np.finfo(np_dtype).max
        elif value == float("-inf"):
            value = np.finfo(np_dtype).min
        return value

    def get_parameter(self, get_origin=True, tensor_platform=Const.MS_FRAMEWORK):
        '''
        Args:
            get_origin: boolean, True returns the stored value as-is; False
                converts tensors to the requested platform first.
            tensor_platform: str, Union["mindspore", "pytorch", "mindtorch"]

        Return:
            parameter: Union[int, float, str, slice, tuple, torch.Tensor, mindspore.Tensor]
        '''
        if self.parameter is None:
            return self.parameter
        if isinstance(self.parameter, tuple):
            # Recurse into nested ComputeElements.
            return tuple([compute_element.get_parameter(get_origin=get_origin, tensor_platform=tensor_platform)
                          for compute_element in self.parameter])
        elif isinstance(self.parameter, self.supported_parameter_type):
            parameter_tmp = self.parameter
        elif isinstance(self.parameter, DtypeMetaData):
            # Resolve the recorded dtype name to the platform's dtype object.
            if tensor_platform == Const.MS_FRAMEWORK:
                parameter_tmp = dtype_str_to_ms_dtype.get(self.parameter.dtype_str)
            elif tensor_platform == Const.PT_FRAMEWORK:
                parameter_tmp = dtype_str_to_torch_dtype.get(self.parameter.dtype_str)
            elif tensor_platform == Const.MT_FRAMEWORK:
                parameter_tmp = dtype_str_to_mindtorch_dtype.get(self.parameter.dtype_str)
        elif isinstance(self.parameter, MstensorMetaData):
            # Lazily materialize the tensor: construct random data from the
            # recorded Max/Min/shape, or load the dumped .npy file.
            mstensor_meta_data = self.parameter
            ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
            if global_context.get_is_constructed():
                np_dtype = dtype_str_to_np_dtype.get(mstensor_meta_data.dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
                ndarray = self._construct_ndarray(mstensor_meta_data.shape, mstensor_meta_data.maximum,
                                                  mstensor_meta_data.minimum, np_dtype)
            else:
                ndarray = load_npy(mstensor_meta_data.npy_path)
            parameter_tmp = mindspore.Tensor(ndarray, dtype=ms_dtype)
        else:
            err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \
                      "(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))

        # if necessary, convert the tensor to the requested platform
        if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
            parameter = self.transfer_to_torch_tensor(parameter_tmp)
        elif not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.MT_FRAMEWORK:
            parameter = self.transfer_to_mindtorch_tensor(parameter_tmp)
        elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK:
            parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
        else:
            parameter = parameter_tmp

        return parameter

    def get_shape(self):
        """Shape recorded at construction (empty tuple for non-tensors)."""
        return self.shape

    def get_dtype(self):
        """Dtype name string recorded at construction."""
        return self.dtype_str

    def _construct_ndarray(self, shape, maximum, minimum, np_dtype):
        # Deterministically construct random data inside [minimum, maximum].
        shape = tuple(shape)
        np.random.seed({random_seed})
        if np_dtype == np.bool_:
            ndarray = np.random.rand(*shape) > 0.5
        else:
            maximum = self.convert_inf_to_real_num(maximum, np_dtype)
            minimum = self.convert_inf_to_real_num(minimum, np_dtype)
            ndarray = np.random.uniform(minimum, maximum, shape).astype(np_dtype)
        return ndarray

    def _init_from_null_compute_element_info(self):
        # Placeholder element for a missing value (e.g. absent optional arg).
        self.parameter = None
        self.shape = tuple()
        # Fix: get_dtype() reads `dtype_str`; the original only set `dtype`,
        # making get_dtype() raise AttributeError for null elements.  The old
        # attribute is kept for backward compatibility.
        self.dtype = "None"
        self.dtype_str = "None"

    def _init_from_compute_element_info(self, compute_element_info):
        '''
        Args:
            compute_element_info: Union[list, dict]

        Return:
            void

        init member attributes: self.shape, self.dtype_str, self.parameter
        '''
        if isinstance(compute_element_info, list):
            # A list describes a tuple of nested elements.
            self.shape = tuple()
            self.dtype_str = TUPLE_TYPE_STR
            self.parameter = tuple([ComputeElement(compute_element_info=sub_info)
                                    for sub_info in compute_element_info])
        else:
            type_str = check_and_get_from_json_dict(compute_element_info, "type", "type field in api_info.json",
                                                    accepted_type=str, accepted_value=api_info_type_str_to_type.keys())
            self.shape = tuple()
            self.dtype_str = type_str
            if type_str == MINDSPORE_TENSOR_TYPE_STR:
                self._init_from_mstensor_compute_element_info(compute_element_info)
            else:
                value = check_and_get_from_json_dict(compute_element_info, "value", "value field in api_info.json")
                if type_str == MINDSPORE_DTYPE_TYPE_STR:
                    self.parameter = DtypeMetaData(value)
                elif type_str == SLICE_TYPE_STR:
                    self.parameter = slice(*tuple(value))
                else:  # type_str in ("str", "int", "float", "bool")
                    self.parameter = value

    def _init_from_mstensor_compute_element_info(self, compute_element_info):
        '''
        do not load real tensor, only record meta data
        '''
        dtype_str = check_and_get_from_json_dict(compute_element_info, "dtype", "dtype field in api_info.json",
                                                 accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys())
        shape = check_and_get_from_json_dict(compute_element_info, "shape", "shape field in api_info.json",
                                             accepted_type=(list,))
        if global_context.get_is_constructed():
            # Constructed mode: keep the value range so random data can be
            # generated later; no file path is needed.
            maximum = check_and_get_from_json_dict(compute_element_info, "Max", "Max field in api_info.json",
                                                   accepted_type=(int, float))
            minimum = check_and_get_from_json_dict(compute_element_info, "Min", "Min field in api_info.json",
                                                   accepted_type=(int, float))

            npy_path = None
        else:
            maximum, minimum = None, None
            data_name = check_and_get_from_json_dict(compute_element_info, "data_name",
                                                     "data_name field in api_info.json", accepted_type=(str,))
            npy_path = os.path.join(global_context.get_dump_data_dir(), data_name)
        mstensor_meta_data = MstensorMetaData(dtype_str, npy_path, maximum, minimum, shape)
        self.parameter = mstensor_meta_data
        self.dtype_str = dtype_str
        self.shape = tuple(shape)

    def _init_with_parameter(self, parameter):
        # Fix: removed leftover debug print() calls that dumped every
        # parameter and the supported-type tuple to stdout.
        self.parameter = parameter
        if isinstance(parameter, dict):
            # A dict is treated as serialized compute_element_info
            # ('type'/'shape'/'dtype'-style fields).
            return self._init_from_compute_element_info(parameter)
        self.shape = tuple()
        if not isinstance(parameter, self.supported_parameter_type):
            err_msg = "ComputeElement._init_with_parameter failed: " \
                      "parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        if isinstance(parameter, mindspore.Tensor):
            self.shape = tuple(parameter.shape)
            self.dtype_str = ms_dtype_to_dtype_str.get(parameter.dtype)
        elif isinstance(parameter, torch.Tensor):
            self.shape = tuple(parameter.shape)
            self.dtype_str = torch_dtype_to_dtype_str.get(parameter.dtype)
        elif isinstance(parameter, typing.Type):
            self.dtype_str = MINDSPORE_DTYPE_TYPE_STR
            self.parameter = DtypeMetaData(ms_dtype_to_dtype_str.get(parameter))
        elif isinstance(parameter, torch.dtype):
            self.dtype_str = TORCH_DTYPE_TYPE_STR
            self.parameter = DtypeMetaData(torch_dtype_to_dtype_str.get(parameter))
        elif isinstance(parameter, tuple):
            self.dtype_str = TUPLE_TYPE_STR
            self.parameter = tuple([ComputeElement(parameter=param) for param in parameter])
        else:
            self.dtype_str = type_to_api_info_type_str.get(type(parameter))
1928
+
1929
class BasicInfoAndStatus:
    """Comparison summary for a single output slot of a single API call."""

    def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
        self.api_name = api_name
        self.bench_dtype = bench_dtype
        self.tested_dtype = tested_dtype
        self.shape = shape
        self.status = status
        self.err_msg = err_msg
1937
+
1938
+ # ======== API execution helpers ========
1939
+
1940
def get_input(propagation):
    """Build an ApiInputAggregation from the templated argument specs.

    The `{{...}}` tokens below are substituted by the script generator with
    serialized argument descriptions before this file is executed.  The
    `propagation` parameter is kept for interface compatibility; it is not
    read here.
    """
    args_info_forward = {args_info_forward}
    kwargs_info_forward = {kwargs_info_forward}
    args_info_backward = {args_info_backward}

    forward_inputs = [ComputeElement(compute_element_info=info)
                      for info in args_info_forward]
    forward_kwargs = {
        name: ComputeElement(compute_element_info=info)
        for name, info in kwargs_info_forward.items()
    }
    # Gradient inputs exist only when a backward spec was dumped.
    if args_info_backward:
        gradient_inputs = [ComputeElement(compute_element_info=info)
                           for info in args_info_backward]
    else:
        gradient_inputs = None
    return ApiInputAggregation(forward_inputs, forward_kwargs, gradient_inputs)
1960
+
1961
+ # Run-and-compare helper function
1962
def run_and_compare_helper(api_name_str, api_input_aggregation, forward_or_backward):
    """
    Args:
        api_name_str: str, e.g. "MintFunctional.relu.0"
        api_input_aggregation: ApiInputAggregation
        forward_or_backward: str: Union["forward", "backward"]

    Return:
        output_list: List[tuple(str, str, BasicInfoAndStatus, dict{str: CompareResult})]

    Description:
        run the tested-framework api, run the torch (bench) api,
        compare each output slot with every registered compare algorithm,
        and record one entry per slot.
    """
    # Run on the tested framework (mindspore / mindtorch).  Those platforms
    # return extra elements alongside the outputs (see ApiRunner.run_api).
    if forward_or_backward == Const.FORWARD:
        tested_outputs, inputs, kwargs, forward_result_tuple = api_runner(api_input_aggregation, api_name_str,
                                                                          forward_or_backward,
                                                                          global_context.get_framework())
        print(f"inputs:{inputs}")
        print(f"kwargs:{kwargs}")
        print(f"forward_result_tuple:{forward_result_tuple}")
    elif forward_or_backward == Const.BACKWARD:
        tested_outputs, gradient_inputs, backward_result_tuple = api_runner(api_input_aggregation, api_name_str,
                                                                            forward_or_backward,
                                                                            global_context.get_framework())
        print(f"gradient_inputs:{gradient_inputs}")
        print(f"backward_result_tuple:{backward_result_tuple}")
    else:
        tested_outputs = api_runner(api_input_aggregation, api_name_str,
                                    forward_or_backward, global_context.get_framework())

    # Benchmark run on PyTorch always returns a plain list of outputs.
    bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)

    tested_outputs = trim_output_compute_element_list(tested_outputs, forward_or_backward)
    bench_outputs = trim_output_compute_element_list(bench_outputs, forward_or_backward)

    # Compare output slot by slot.
    output_list = []
    for i, (bench_out, tested_out) in enumerate(zip(bench_outputs, tested_outputs)):
        api_name_with_slot = Const.SEP.join([api_name_str, forward_or_backward, Const.OUTPUT, str(i)])
        bench_dtype = bench_out.get_dtype()
        tested_dtype = tested_out.get_dtype()
        shape = bench_out.get_shape()

        compare_result_dict = dict()
        for compare_algorithm_name, compare_algorithm in compare_algorithms.items():
            compare_result = compare_algorithm(bench_out, tested_out)
            compare_result_dict[compare_algorithm_name] = compare_result

        # Overall PASS requires both cosine similarity and max-abs-error to pass.
        if compare_result_dict.get(CompareConst.COSINE).pass_status == CompareConst.PASS and \
                compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
            status = CompareConst.PASS
            err_msg = ""
        else:
            status = CompareConst.ERROR
            err_msg = (compare_result_dict.get(CompareConst.COSINE).err_msg +
                       compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg)

        basic_info_status = \
            BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
        output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
    return output_list
2028
+
2029
+
2030
+ if __name__ == "__main__":
2031
+ framework = "{framework}"
2032
+ dump_data_dir = "{real_data_path}"
2033
+ api_name = "{api_name}"
2034
+ api_full_name = "{api_full_name}"
2035
+ api_name_str = ".".join(api_full_name.split(".")[:3])
2036
+ propagation = "{propagation}"
2037
+ data_mode = "{data_mode}"
2038
+ seed_all({random_seed})
2039
+
2040
+ data_manager = DataManager("./op_result_output", None)
2041
+ create_directory("./op_result_output")
2042
+
2043
+ is_constructed = data_mode == "random_data"
2044
+ global_context.init(is_constructed, dump_data_dir, framework)
2045
+
2046
+ for i in range({iter_times}):
2047
+ print(f"iter: {{i}}:")
2048
+ if propagation == BACKWARD:
2049
+
2050
+
2051
+ backward_inputs_aggregation = get_input(propagation)
2052
+
2053
+ backward_output_list = run_and_compare_helper(api_name_str, backward_inputs_aggregation,
2054
+ Const.BACKWARD)
2055
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS,
2056
+ result=backward_output_list, err_msg="")
2057
+
2058
+
2059
+ if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS:
2060
+ data_manager.record(process_result_packet.result)
2061
+ elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
2062
+ data_manager.record_exception_skip(api_name_str, Const.BACKWARD, process_result_packet.err_msg)
2063
+
2064
+ data_manager.save_results(api_name_str)
2065
+ else:
2066
+ forward_inputs_aggregation = get_input(propagation)
2067
+
2068
+ forward_output_list = run_and_compare_helper(api_name_str, forward_inputs_aggregation,
2069
+ Const.FORWARD)
2070
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS,
2071
+ result=forward_output_list, err_msg="")
2072
+
2073
+
2074
+ if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS:
2075
+ data_manager.record(process_result_packet.result)
2076
+ elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
2077
+ data_manager.record_exception_skip(api_name_str, Const.FORWARD, process_result_packet.err_msg)
2078
+
2079
+ data_manager.save_results(api_name_str)
2080
+
2081
+ print("Compare finished.")