mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,9 @@
1
1
  #!/usr/bin/env python3
2
2
  # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
7
  # you may not use this file except in compliance with the License.
7
8
  # You may obtain a copy of the License at
8
9
  #
@@ -13,7 +14,6 @@
13
14
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
15
  # See the License for the specific language governing permissions and
15
16
  # limitations under the License.
16
- """
17
17
 
18
18
  import os
19
19
  import math
@@ -22,19 +22,28 @@ import numpy
22
22
 
23
23
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
24
24
  from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \
25
- CompareException
25
+ CompareException, get_module_and_atttribute_name, get_attribute
26
26
  from msprobe.core.common.file_utils import FileChecker, load_npy
27
27
  from msprobe.pytorch.common.log import logger
28
28
  from msprobe.pytorch.common.utils import load_pt
29
- from msprobe.core.common.const import Const, FileCheckConst
29
+ from msprobe.core.common.const import Const, FileCheckConst, CompareConst
30
30
 
31
31
  TORCH_TYPE = ["torch.device", "torch.dtype"]
32
32
  TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
33
- FLOAT_TYPE = ['torch.float32', 'torch.float', 'torch.float64', 'torch.double', 'torch.float16',
34
- 'torch.half', 'torch.bfloat16']
35
- NUMPY_TYPE = ["numpy.int8", "numpy.int16", "numpy.int32", "numpy.int64", "numpy.uint8", "numpy.uint16", "numpy.uint32",
36
- "numpy.uint64", "numpy.float16", "numpy.float32", "numpy.float64", "numpy.float128", "numpy.complex64",
37
- "numpy.complex128", "numpy.complex256", "numpy.bool_", "numpy.string_", "numpy.bytes_", "numpy.unicode_"]
33
+ FLOAT_TYPE = [
34
+ 'torch.float32',
35
+ 'torch.float',
36
+ 'torch.float64',
37
+ 'torch.double',
38
+ 'torch.float16',
39
+ 'torch.half',
40
+ 'torch.bfloat16'
41
+ ]
42
+ NUMPY_TYPE = [
43
+ "numpy.int8", "numpy.int16", "numpy.int32", "numpy.int64", "numpy.uint8", "numpy.uint16", "numpy.uint32",
44
+ "numpy.uint64", "numpy.float16", "numpy.float32", "numpy.float64", "numpy.float128", "numpy.complex64",
45
+ "numpy.complex128", "numpy.complex256", "numpy.bool_", "numpy.string_", "numpy.bytes_", "numpy.unicode_"
46
+ ]
38
47
 
39
48
 
40
49
  def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
@@ -68,7 +77,8 @@ def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
68
77
  raise Exception("{} is not supported now".format(data_type))
69
78
  data = info.get("value")
70
79
  try:
71
- data = eval(data_type)(data)
80
+ module_name, attribute_name = get_module_and_atttribute_name(data_type)
81
+ data = get_attribute(module_name, attribute_name)(data)
72
82
  except Exception as err:
73
83
  logger.error("Failed to convert the type to numpy: %s" % str(err))
74
84
  elif data_type == "torch.Size":
@@ -104,8 +114,9 @@ def gen_real_tensor(data_path, convert_type):
104
114
  if convert_type:
105
115
  ori_dtype = Const.CONVERT.get(convert_type)[0]
106
116
  dist_dtype = Const.CONVERT.get(convert_type)[1]
117
+ module_name, attribute_name = get_module_and_atttribute_name(dist_dtype)
107
118
  if str(data.dtype) == ori_dtype:
108
- data = data.type(eval(dist_dtype))
119
+ data = data.type(get_attribute(module_name, attribute_name))
109
120
  return data
110
121
 
111
122
 
@@ -118,8 +129,12 @@ def gen_random_tensor(info, convert_type):
118
129
  convert_type: convert ori_type to dist_type flag.
119
130
  """
120
131
  check_object_type(info, dict)
121
- low, high = info.get('Min'), info.get('Max')
122
- low_origin, high_origin = info.get('Min_origin'), info.get('Max_origin')
132
+
133
+ low_origin = info.get('Min')
134
+ low = info.get('Min_except_inf_nan', low_origin)
135
+ high_origin = info.get('Max')
136
+ high = info.get('Max_except_inf_nan', high_origin)
137
+
123
138
  low_info = [low, low_origin]
124
139
  high_info = [high, high_origin]
125
140
  data_dtype = info.get('dtype')
@@ -164,33 +179,35 @@ def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type):
164
179
  data_dtype = Const.CONVERT.get(convert_type)[1]
165
180
  low, low_origin = low_info[0], low_info[1]
166
181
  high, high_origin = high_info[0], high_info[1]
167
- if data_dtype in FLOAT_TYPE:
182
+ module_name, attribute_name = get_module_and_atttribute_name(data_dtype)
183
+ dtype = get_attribute(module_name, attribute_name)
184
+ if data_dtype in FLOAT_TYPE:
168
185
  if math.isnan(high):
169
- tensor = torch._C._VariableFunctionsClass.full(shape, float('nan'), dtype=eval(data_dtype))
186
+ tensor = torch.full(shape, float('nan'), dtype=dtype)
170
187
  return tensor
171
188
  #high_origin为新版json中的属性,只有当high_origin不为None,且high为inf或-inf时,原tensor全为inf或-inf
172
- if high_origin and high in [float('inf'), float('-inf')]:
173
- tensor = torch._C._VariableFunctionsClass.full(shape, high, dtype=eval(data_dtype))
189
+ if high_origin and high in [float(CompareConst.INF), float(CompareConst.NEG_INF)]:
190
+ tensor = torch.full(shape, high, dtype=dtype)
174
191
  tensor[-1] = low
175
192
  return tensor
176
193
  low_scale, high_scale = low, high
177
- dtype_finfo = torch.finfo(eval(data_dtype))
194
+ dtype_finfo = torch.finfo(dtype)
178
195
  #适配老版json high和low为inf或-inf的情况,取dtype的最大值或最小值进行放缩
179
- if high == float('inf'):
196
+ if high == float(CompareConst.INF):
180
197
  high_scale = dtype_finfo.max
181
- elif high == float('-inf'):
198
+ elif high == float(CompareConst.NEG_INF):
182
199
  high_scale = dtype_finfo.min
183
- if low == float('inf'):
200
+ if low == float(CompareConst.INF):
184
201
  low_scale = dtype_finfo.max
185
- elif low == float('-inf'):
202
+ elif low == float(CompareConst.NEG_INF):
186
203
  low_scale = dtype_finfo.min
187
204
 
188
205
  scale = high_scale - low_scale
189
- rand01 = torch.rand(shape, dtype=eval(data_dtype))
206
+ rand01 = torch.rand(shape, dtype=dtype)
190
207
  tensor = rand01 * scale + low_scale
191
208
  elif 'int' in data_dtype or 'long' in data_dtype:
192
209
  low, high = int(low), int(high)
193
- tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype))
210
+ tensor = torch.randint(low, high + 1, shape, dtype=dtype)
194
211
  else:
195
212
  logger.error('Dtype is not supported: ' + data_dtype)
196
213
  raise NotImplementedError()
@@ -208,9 +225,9 @@ def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type):
208
225
  else:
209
226
  tmp_tensor[0] = low
210
227
  tmp_tensor[-1] = high
211
- if high_origin in [float('inf'), float('-inf')]:
228
+ if high_origin in [float(CompareConst.INF), float(CompareConst.NEG_INF)]:
212
229
  tmp_tensor[-1] = high_origin
213
- if low_origin in [float('inf'), float('-inf')]:
230
+ if low_origin in [float(CompareConst.INF), float(CompareConst.NEG_INF)]:
214
231
  tmp_tensor[0] = low_origin
215
232
  data = tmp_tensor.reshape(shape)
216
233
  return data
@@ -233,7 +250,7 @@ def gen_bool_tensor(low, high, shape):
233
250
  return data
234
251
 
235
252
 
236
- def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
253
+ def gen_args(args_info, api_name, func_options):
237
254
  """
238
255
  Function Description:
239
256
  Based on API basic information, generate input parameters: args, for API forward running
@@ -246,9 +263,20 @@ def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_p
246
263
  """
247
264
  check_object_type(args_info, list)
248
265
  args_result = []
266
+
267
+ need_grad = func_options.get('need_grad', True)
268
+ convert_type = func_options.get('convert_type', None)
269
+ real_data_path = func_options.get('real_data_path', None)
270
+ depth = func_options.get('depth', 0)
271
+
272
+ if depth > Const.MAX_DEPTH:
273
+ logger.error("The depth of args is too large, please check the input args.")
274
+ raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
275
+
249
276
  for arg in args_info:
250
277
  if isinstance(arg, (list, tuple)):
251
- data = gen_args(arg, api_name, need_grad, convert_type, real_data_path)
278
+ func_options['depth'] = depth + 1
279
+ data = gen_args(arg, api_name, func_options)
252
280
  elif isinstance(arg, dict):
253
281
  data = gen_data(arg, api_name, need_grad, convert_type, real_data_path)
254
282
  elif arg is None:
@@ -288,7 +316,8 @@ def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):
288
316
 
289
317
  def gen_torch_kwargs(kwargs_params, key, value):
290
318
  if value.get('type') != "torch.device":
291
- kwargs_params[key] = eval(value.get('value'))
319
+ module_name, attribute_name = get_module_and_atttribute_name(value.get('value'))
320
+ kwargs_params[key] = get_attribute(module_name, attribute_name)
292
321
 
293
322
 
294
323
  def gen_list_kwargs(kwargs_item_value, api_name, convert_type, real_data_path=None):
@@ -327,8 +356,14 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d
327
356
  error_info = f"convert_type params not support {convert_type}."
328
357
  raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
329
358
  kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path)
359
+ func_options = {
360
+ 'need_grad': need_grad,
361
+ 'convert_type': convert_type,
362
+ 'real_data_path': real_data_path,
363
+ 'depth': 0
364
+ }
330
365
  if api_info.get("input_args"):
331
- args_params = gen_args(api_info.get("input_args"), api_name, need_grad, convert_type, real_data_path)
366
+ args_params = gen_args(api_info.get("input_args"), api_name, func_options)
332
367
  else:
333
368
  logger.warning(f'Warning: No args in {api_info} ')
334
369
  args_params = []
@@ -1,3 +1,20 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
1
18
  import subprocess
2
19
  import json
3
20
  import os
@@ -105,7 +122,7 @@ def run_parallel_ut(config):
105
122
  if output == '':
106
123
  break
107
124
  if '[ERROR]' in output:
108
- print(output, end='')
125
+ logger.warning(output, end='')
109
126
  sys.stdout.flush()
110
127
  except ValueError as e:
111
128
  logger.warning(f"An error occurred while reading subprocess output: {e}")
@@ -119,7 +136,8 @@ def run_parallel_ut(config):
119
136
 
120
137
  for api_info in config.api_files:
121
138
  cmd = create_cmd(api_info, next(device_id_cycle))
122
- process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, bufsize=1, shell=False)
139
+ process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
140
+ text=True, bufsize=1, shell=False)
123
141
  processes.append(process)
124
142
  threading.Thread(target=read_process_output, args=(process,), daemon=True).start()
125
143
 
@@ -150,7 +168,8 @@ def run_parallel_ut(config):
150
168
  logger.error(f"An unexpected error occurred: {e}")
151
169
  finally:
152
170
  if progress_bar.n < config.total_items:
153
- logger.warning("The UT task has not been completed. The parameter '-csv_path' along with the path to the result CSV file will be utilized to resume the UT task.")
171
+ logger.warning("The UT task has not been completed. The parameter '-csv_path' along with the path to " \
172
+ "the result CSV file will be utilized to resume the UT task.")
154
173
  clean_up()
155
174
  progress_bar_thread.join()
156
175
  try:
@@ -173,7 +192,8 @@ def prepare_config(args):
173
192
  out_path = out_path_checker.common_check()
174
193
  split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
175
194
  config_path = os.path.realpath(args.config_path) if args.config_path else None
176
- result_csv_path = args.result_csv_path or os.path.join(out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
195
+ result_csv_path = args.result_csv_path or os.path.join(
196
+ out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
177
197
  if not args.result_csv_path:
178
198
  details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv")
179
199
  comparator = Comparator(result_csv_path, details_csv_path, False)
@@ -190,7 +210,8 @@ def prepare_config(args):
190
210
  def main():
191
211
  parser = argparse.ArgumentParser(description='Run UT in parallel')
192
212
  _run_ut_parser(parser)
193
- parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8, help='Number of splits for parallel processing. Range: 1-64')
213
+ parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
214
+ help='Number of splits for parallel processing. Range: 1-64')
194
215
  args = parser.parse_args()
195
216
  config = prepare_config(args)
196
217
  run_parallel_ut(config)
@@ -1,3 +1,20 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
1
18
  import argparse
2
19
  import os
3
20
  import sys
@@ -24,8 +41,8 @@ def check_tensor_overflow(x):
24
41
  tensor_max = x.cpu().detach().float().numpy().tolist()
25
42
  tensor_min = tensor_max
26
43
  else:
27
- tensor_max = torch._C._VariableFunctionsClass.max(x).cpu().detach().float().numpy().tolist()
28
- tensor_min = torch._C._VariableFunctionsClass.min(x).cpu().detach().float().numpy().tolist()
44
+ tensor_max = torch.max(x).cpu().detach().float().numpy().tolist()
45
+ tensor_min = torch.min(x).cpu().detach().float().numpy().tolist()
29
46
  # inf
30
47
  if tensor_max == float('inf') or tensor_min == float('-inf'):
31
48
  return True
@@ -1,3 +1,20 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
1
18
  import argparse
2
19
  import os
3
20
  import csv
@@ -17,8 +34,8 @@ else:
17
34
  import torch
18
35
  from tqdm import tqdm
19
36
 
20
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api, UtDataInfo, \
21
- get_validated_result_csv_path, get_validated_details_csv_path, exec_api
37
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import BackwardMessage, UtDataInfo, \
38
+ get_validated_result_csv_path, get_validated_details_csv_path, exec_api, record_skip_info
22
39
  from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
23
40
  from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \
24
41
  initialize_save_path, UtDataProcessor, extract_basic_api_segments, ApiData
@@ -26,13 +43,14 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
26
43
  from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
27
44
  from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
28
45
  from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
29
- from msprobe.core.common.file_utils import FileOpen, FileChecker, \
30
- change_mode, check_path_before_create, create_directory, get_json_contents
46
+ from msprobe.core.common.file_utils import FileChecker, change_mode, check_path_before_create, \
47
+ create_directory, get_json_contents, read_csv
31
48
  from msprobe.pytorch.common.log import logger
32
49
  from msprobe.pytorch.pt_config import parse_json_config
33
50
  from msprobe.core.common.const import Const, FileCheckConst, CompareConst
34
51
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
35
52
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
53
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params
36
54
 
37
55
 
38
56
  current_time = time.strftime("%Y%m%d%H%M%S")
@@ -46,14 +64,7 @@ RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content',
46
64
  OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
47
65
 
48
66
  not_backward_list = ['repeat_interleave']
49
- not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
50
- not_raise_dtype_set = {'type_as'}
51
67
 
52
- RAISE_PRECISION = {
53
- torch.float16: torch.float32,
54
- torch.bfloat16: torch.float32,
55
- torch.float32: torch.float64
56
- }
57
68
 
58
69
  tqdm_params = {
59
70
  'smoothing': 0, # 平滑进度条的预计剩余时间,取值范围0到1
@@ -71,98 +82,6 @@ tqdm_params = {
71
82
  }
72
83
 
73
84
 
74
- def deal_detach(arg, to_detach=True):
75
- return arg.detach() if to_detach else arg
76
-
77
-
78
- def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
79
- '''
80
- 将标杆数据的dtype转换为raise_dtype
81
- 输入:
82
- api_name:api名称
83
- arg:标杆输入
84
- raise_dtype:需要转换的dtype
85
- 输出:
86
- arg: 转换dtype的标杆输入
87
- '''
88
- if api_name in hf_32_standard_api and arg.dtype == torch.float32:
89
- return arg
90
- if raise_dtype is None or arg.dtype not in RAISE_PRECISION or raise_dtype == arg.dtype:
91
- return arg
92
- return arg.type(raise_dtype)
93
-
94
-
95
- def generate_device_params(input_args, input_kwargs, need_backward, api_name):
96
- def recursive_arg_to_device(arg_in, to_detach):
97
- if isinstance(arg_in, (list, tuple)):
98
- return type(arg_in)(recursive_arg_to_device(arg, to_detach) for arg in arg_in)
99
- elif isinstance(arg_in, torch.Tensor):
100
- if need_backward and arg_in.requires_grad:
101
- arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_()
102
- temp_arg_in = arg_in * 1
103
- arg_in = temp_arg_in.type_as(arg_in)
104
- arg_in.retain_grad()
105
- return arg_in
106
- else:
107
- return deal_detach(arg_in.clone(), to_detach).to(current_device)
108
- else:
109
- return arg_in
110
-
111
- is_detach = api_name not in not_detach_set
112
- device_args = recursive_arg_to_device(input_args, is_detach)
113
- device_kwargs = \
114
- {key: recursive_arg_to_device(value, key != "out" and is_detach) for key, value in input_kwargs.items()}
115
- return device_args, device_kwargs
116
-
117
-
118
- def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
119
- def recursive_arg_to_cpu(arg_in, to_detach, raise_dtype=None):
120
- if isinstance(arg_in, (list, tuple)):
121
- return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype) for arg in arg_in)
122
- elif isinstance(arg_in, torch.Tensor):
123
- if need_backward and arg_in.requires_grad:
124
- arg_in = deal_detach(raise_bench_data_dtype(
125
- api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_()
126
- temp_arg_in = arg_in * 1
127
- arg_in = temp_arg_in.type_as(arg_in)
128
- arg_in.retain_grad()
129
- return arg_in
130
- else:
131
- return deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach)
132
- else:
133
- return arg_in
134
-
135
- def is_tensor_with_raise_precision(arg_in, check_kwargs=False):
136
- if arg_in.dtype in RAISE_PRECISION:
137
- return True
138
- if check_kwargs and arg_in.dtype in [torch.half, torch.bfloat16]:
139
- return True
140
- return False
141
-
142
- def recursive_find_dtypes(arg_in, kwargs=None, check_kwargs=False):
143
- if isinstance(arg_in, (list, tuple)):
144
- return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs) for arg in arg_in))
145
- elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs):
146
- return set([arg_in.dtype])
147
- elif isinstance(arg_in, dict) and check_kwargs:
148
- return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True) for v in arg_in.values()))
149
- return set()
150
-
151
- raise_dtype = None
152
- need_raise_dtypes = recursive_find_dtypes(input_args)
153
- need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
154
- if len(need_raise_dtypes) == 1:
155
- raise_dtype = RAISE_PRECISION.get(need_raise_dtypes.pop(), torch.float32)
156
- elif len(need_raise_dtypes) >= 2:
157
- raise_dtype = torch.float32
158
-
159
- raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
160
- is_detach = api_name not in not_detach_set
161
- cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
162
- cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for key, value in input_kwargs.items()}
163
- return cpu_args, cpu_kwargs
164
-
165
-
166
85
  def run_ut(config):
167
86
  logger.info("start UT test")
168
87
  if config.online_config.is_online:
@@ -179,10 +98,8 @@ def run_ut(config):
179
98
  if config.online_config.is_online:
180
99
  run_api_online(config, compare)
181
100
  else:
182
- with FileOpen(config.result_csv_path, 'r') as file:
183
- csv_reader = csv.reader(file)
184
- next(csv_reader)
185
- api_name_set = {row[0] for row in csv_reader}
101
+ csv_df = read_csv(config.result_csv_path)
102
+ api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
186
103
  run_api_offline(config, compare, api_name_set)
187
104
  for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
188
105
  change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -198,17 +115,23 @@ def run_api_offline(config, compare, api_name_set):
198
115
  if api_full_name in api_name_set:
199
116
  continue
200
117
  if is_unsupported_api(api_full_name):
118
+ skip_message = f"API {api_full_name} not support for run ut. SKIP."
119
+ compare_alg_results = err_column.to_column_value(CompareConst.SKIP, skip_message)
120
+ record_skip_info(api_full_name, compare, compare_alg_results)
201
121
  continue
202
122
  _, api_name = extract_basic_api_segments(api_full_name)
203
123
  if not api_name:
204
124
  err_message = f"API {api_full_name} not support for run ut. SKIP."
205
125
  logger.error(err_message)
206
- fwd_compare_alg_results = err_column.to_column_value(CompareConst.SKIP, err_message)
207
- result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [fwd_compare_alg_results], None, 0)
208
- compare.record_results(result_info)
126
+ compare_alg_results = err_column.to_column_value(CompareConst.SKIP, err_message)
127
+ record_skip_info(api_full_name, compare, compare_alg_results)
209
128
  continue
210
129
  try:
211
130
  if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
131
+ skip_message = f"API {api_name} in black list or not in white list. SKIP."
132
+ logger.info(skip_message)
133
+ compare_alg_results = err_column.to_column_value(CompareConst.SKIP, skip_message)
134
+ record_skip_info(api_full_name, compare, compare_alg_results)
212
135
  continue
213
136
  data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict)
214
137
  is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info)
@@ -220,9 +143,8 @@ def run_api_offline(config, compare, api_name_set):
220
143
  f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
221
144
  else:
222
145
  logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
223
- fwd_compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err))
224
- result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [fwd_compare_alg_results], None, 0)
225
- compare.record_results(result_info)
146
+ compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err))
147
+ record_skip_info(api_full_name, compare, compare_alg_results)
226
148
  finally:
227
149
  if is_gpu:
228
150
  torch.cuda.empty_cache()
@@ -327,12 +249,12 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
327
249
  in_fwd_data_list.append(kwargs)
328
250
  need_backward = api_full_name in backward_content
329
251
  if not need_grad:
330
- logger.warning("%s %s" % (api_full_name, Backward_Message.UNSUPPORT_BACKWARD_MESSAGE))
331
- backward_message += Backward_Message.UNSUPPORT_BACKWARD_MESSAGE
252
+ logger.warning("%s %s" % (api_full_name, BackwardMessage.UNSUPPORT_BACKWARD_MESSAGE))
253
+ backward_message += BackwardMessage.UNSUPPORT_BACKWARD_MESSAGE
332
254
  if api_name in not_backward_list:
333
255
  need_grad = False
334
- logger.warning("%s %s" % (api_full_name, Backward_Message.NO_BACKWARD_RESULT_MESSAGE))
335
- backward_message += Backward_Message.NO_BACKWARD_RESULT_MESSAGE
256
+ logger.warning("%s %s" % (api_full_name, BackwardMessage.NO_BACKWARD_RESULT_MESSAGE))
257
+ backward_message += BackwardMessage.NO_BACKWARD_RESULT_MESSAGE
336
258
  need_backward = need_backward and need_grad
337
259
  if kwargs.get("device"):
338
260
  del kwargs["device"]
@@ -353,13 +275,16 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
353
275
  if need_backward:
354
276
  if need_to_backward(grad_index, out):
355
277
  backward_args = backward_content[api_full_name].get("input")
356
- grad = gen_args(backward_args, api_name, real_data_path=real_data_path)[0]
278
+ func_options = {
279
+ 'real_data_path': real_data_path
280
+ }
281
+ grad = gen_args(backward_args, api_name, func_options)[0]
357
282
  bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
358
283
  bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
359
284
  device_grad = grad.clone().detach().to(current_device)
360
285
  device_grad_out = run_backward(device_args, device_grad, grad_index, device_out)
361
286
  else:
362
- backward_message += Backward_Message.MULTIPLE_BACKWARD_MESSAGE
287
+ backward_message += BackwardMessage.MULTIPLE_BACKWARD_MESSAGE
363
288
  if api_name == "npu_fusion_attention":
364
289
  out = out[0]
365
290
  device_out = device_out[0]
@@ -416,7 +341,7 @@ def initialize_save_error_data(error_data_path):
416
341
  error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
417
342
  ability=FileCheckConst.WRITE_ABLE)
418
343
  error_data_path = error_data_path_checker.common_check()
419
- error_data_path =initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
344
+ error_data_path = initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
420
345
  return error_data_path
421
346
 
422
347
 
@@ -477,7 +402,8 @@ def preprocess_forward_content(forward_content):
477
402
  if key not in arg_cache:
478
403
  filtered_new_args = [
479
404
  {k: v for k, v in arg.items() if k not in ['Max', 'Min']}
480
- for arg in value['input_args'] if isinstance(arg, dict)
405
+ for arg in value['input_args']
406
+ if isinstance(arg, dict)
481
407
  ]
482
408
  arg_cache[key] = (filtered_new_args, value['input_kwargs'])
483
409
 
@@ -529,14 +455,14 @@ def run_ut_command(args):
529
455
  # 离线场景下,forward_content, backward_content, real_data_path从api_info_file中解析
530
456
  forward_content, backward_content, real_data_path = None, None, None
531
457
  if args.api_info_file:
532
- api_info_file_checker = FileChecker(file_path = args.api_info_file, path_type = FileCheckConst.FILE,
533
- ability = FileCheckConst.READ_ABLE, file_type = FileCheckConst.JSON_SUFFIX)
458
+ api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
459
+ ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
534
460
  checked_api_info = api_info_file_checker.common_check()
535
461
  forward_content, backward_content, real_data_path = parse_json_info_forward_backward(checked_api_info)
536
462
  if args.filter_api:
537
- logger.info("Start filtering the api in the forward_input_file.")
463
+ logger.info("Start filtering the api in the api_info_file.")
538
464
  forward_content = preprocess_forward_content(forward_content)
539
- logger.info("Finish filtering the api in the forward_input_file.")
465
+ logger.info("Finish filtering the api in the api_info_file.")
540
466
 
541
467
  out_path = os.path.realpath(args.out_path) if args.out_path else "./"
542
468
  check_path_before_create(out_path)