mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +14 -19
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +155 -6
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +3 -0
  10. msprobe/core/common/utils.py +28 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +380 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/multiprocessing_compute.py +2 -2
  22. msprobe/core/compare/npy_compare.py +109 -147
  23. msprobe/core/compare/utils.py +189 -69
  24. msprobe/core/data_dump/data_collector.py +51 -21
  25. msprobe/core/data_dump/data_processor/base.py +38 -20
  26. msprobe/core/data_dump/data_processor/factory.py +5 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
  29. msprobe/core/data_dump/json_writer.py +29 -1
  30. msprobe/core/data_dump/scope.py +19 -18
  31. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  32. msprobe/core/overflow_check/checker.py +1 -1
  33. msprobe/core/overflow_check/utils.py +1 -1
  34. msprobe/docs/01.installation.md +96 -17
  35. msprobe/docs/02.config_introduction.md +5 -5
  36. msprobe/docs/05.data_dump_PyTorch.md +91 -61
  37. msprobe/docs/06.data_dump_MindSpore.md +57 -19
  38. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  39. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
  40. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  41. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  42. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  43. msprobe/docs/19.monitor.md +120 -27
  44. msprobe/docs/21.visualization_PyTorch.md +115 -35
  45. msprobe/docs/22.visualization_MindSpore.md +138 -41
  46. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  47. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  48. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  49. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  50. msprobe/docs/27.dump_json_instruction.md +521 -0
  51. msprobe/docs/FAQ.md +26 -2
  52. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  53. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  54. msprobe/docs/img/merge_result.png +0 -0
  55. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  56. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  57. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  58. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  59. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  60. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  61. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  63. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  64. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  65. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  66. msprobe/docs/visualization/GPTModel.png +0 -0
  67. msprobe/docs/visualization/ParallelMLP.png +0 -0
  68. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  69. msprobe/docs/visualization/mapping.png +0 -0
  70. msprobe/docs/visualization/mapping1.png +0 -0
  71. msprobe/docs/visualization/module_name.png +0 -0
  72. msprobe/docs/visualization/module_name1.png +0 -0
  73. msprobe/docs/visualization/no_mapping.png +0 -0
  74. msprobe/docs/visualization/no_mapping1.png +0 -0
  75. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  76. msprobe/docs/visualization/top_layer.png +0 -0
  77. msprobe/mindspore/__init__.py +10 -0
  78. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
  79. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  80. msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
  81. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  82. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  83. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  84. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  85. msprobe/mindspore/code_mapping/bind.py +264 -0
  86. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  87. msprobe/mindspore/code_mapping/graph.py +49 -0
  88. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  89. msprobe/mindspore/code_mapping/main.py +24 -0
  90. msprobe/mindspore/code_mapping/processor.py +34 -0
  91. msprobe/mindspore/common/const.py +3 -1
  92. msprobe/mindspore/common/utils.py +50 -5
  93. msprobe/mindspore/compare/distributed_compare.py +0 -2
  94. msprobe/mindspore/compare/ms_compare.py +105 -63
  95. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  96. msprobe/mindspore/debugger/debugger_config.py +3 -0
  97. msprobe/mindspore/debugger/precision_debugger.py +81 -12
  98. msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
  99. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  100. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  101. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  102. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  103. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  104. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  105. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  106. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  107. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  108. msprobe/mindspore/grad_probe/hook.py +13 -4
  109. msprobe/mindspore/mindtorch/__init__.py +18 -0
  110. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  111. msprobe/mindspore/ms_config.py +5 -1
  112. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  113. msprobe/mindspore/service.py +267 -101
  114. msprobe/msprobe.py +24 -3
  115. msprobe/pytorch/__init__.py +7 -6
  116. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  117. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  123. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  124. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
  125. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  126. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  127. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  128. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  129. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  130. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  131. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  132. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  133. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  134. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  135. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  136. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  140. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  141. msprobe/pytorch/common/parse_json.py +2 -1
  142. msprobe/pytorch/common/utils.py +45 -2
  143. msprobe/pytorch/compare/distributed_compare.py +17 -29
  144. msprobe/pytorch/compare/pt_compare.py +40 -20
  145. msprobe/pytorch/debugger/debugger_config.py +27 -12
  146. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  147. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  148. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  149. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
  150. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  151. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  152. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  153. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  154. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  155. msprobe/pytorch/hook_module/__init__.py +1 -1
  156. msprobe/pytorch/hook_module/hook_module.py +14 -11
  157. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  158. msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
  159. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  160. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  161. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  162. msprobe/pytorch/monitor/anomaly_detect.py +107 -22
  163. msprobe/pytorch/monitor/csv2tb.py +166 -0
  164. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  165. msprobe/pytorch/monitor/features.py +3 -3
  166. msprobe/pytorch/monitor/module_hook.py +483 -277
  167. msprobe/pytorch/monitor/module_metric.py +27 -48
  168. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  169. msprobe/pytorch/monitor/optimizer_collect.py +52 -14
  170. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  171. msprobe/pytorch/monitor/utils.py +77 -6
  172. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  173. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  174. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  175. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  176. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  177. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  178. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  179. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  180. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  181. msprobe/pytorch/service.py +176 -106
  182. msprobe/visualization/builder/graph_builder.py +62 -5
  183. msprobe/visualization/builder/msprobe_adapter.py +24 -2
  184. msprobe/visualization/compare/graph_comparator.py +64 -14
  185. msprobe/visualization/compare/mode_adapter.py +1 -15
  186. msprobe/visualization/graph/base_node.py +12 -17
  187. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  188. msprobe/visualization/graph/graph.py +9 -0
  189. msprobe/visualization/graph_service.py +97 -23
  190. msprobe/visualization/utils.py +14 -29
  191. msprobe/pytorch/functional/module_dump.py +0 -84
  192. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  193. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
  194. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
  195. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  196. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  197. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -24,15 +24,20 @@ from msprobe.core.common.utils import CompareException
24
24
  from msprobe.core.common.file_utils import get_json_contents, write_csv
25
25
  import torch
26
26
  from msprobe.core.common.const import CompareConst
27
- from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \
28
- get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \
29
- get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \
30
- check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err
27
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_register import StandardRegistry
28
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.absolute_threshold import AbsolutethdCompare
29
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.benchmark_compare import BenchmarkCompare
30
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.ulp_compare import UlpCompare
31
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.binary_consistency import BinaryCompare
32
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.thousandth_standard import ThousandthStdCompare
33
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.accumulative_error_compare import AccumulativeErrorCompare
34
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_input import CompareInput
35
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_err, get_max_abs_err, get_rel_err_ratio, \
36
+ cosine_sim, get_rel_err_origin, get_abs_bench_with_eps, compare_bool_tensor
31
37
  from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
32
38
  from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
33
39
  from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
34
- DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, absolute_standard_api, binary_standard_api, \
35
- ulp_standard_api, thousandth_standard_api, apis_threshold
40
+ DETAIL_TEST_ROWS, BENCHMARK_COMPARE_SUPPORT_LIST
36
41
  from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
37
42
  from msprobe.pytorch.common.log import logger
38
43
 
@@ -42,6 +47,7 @@ ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'b
42
47
 
43
48
 
44
49
  INDEX_TEST_RESULT_GROUP = 3
50
+ BACKWARD_RESULT_GROUP = 4
45
51
  INDEX_FIRST_GROUP = 0
46
52
  INDEX_MESSAGE = -1
47
53
 
@@ -66,6 +72,8 @@ class Comparator:
66
72
  self.detail_save_path_list = \
67
73
  [self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list]
68
74
 
75
+ self.registry = self._register_compare_func()
76
+
69
77
  if not is_continue_run_ut:
70
78
  self.write_csv_title()
71
79
  if stack_info_json_path:
@@ -101,22 +109,6 @@ class Comparator:
101
109
  compare_column.error_rate = 0
102
110
  return CompareConst.PASS, compare_column, ""
103
111
 
104
- @staticmethod
105
- def _compare_bool_tensor(bench_output, device_output):
106
- error_nums = (bench_output != device_output).sum()
107
- if bench_output.size == 0:
108
- return CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result."
109
- error_rate = float(error_nums / bench_output.size)
110
- result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR
111
- return error_rate, result, ""
112
-
113
- @staticmethod
114
- def _get_absolute_threshold_attribute(api_name, dtype):
115
- small_value_threshold = apis_threshold.get(api_name).get(dtype).get('small_value')
116
- small_value_atol = apis_threshold.get(api_name).get(dtype).get('small_value_atol')
117
- rtol = apis_threshold.get(api_name).get(dtype).get('rtol')
118
- return small_value_threshold, small_value_atol, rtol
119
-
120
112
  @staticmethod
121
113
  def _get_run_ut_detail(test_result):
122
114
  """get run_ut detail before write to csv, called by online run_ut"""
@@ -143,6 +135,36 @@ class Comparator:
143
135
  test_rows.append([subject] + list(test_subject))
144
136
  return test_rows
145
137
 
138
+ @staticmethod
139
+ def _binary_standard_compare(input_data):
140
+ binary_compare = BinaryCompare(input_data)
141
+ binary_compare.compare()
142
+
143
+ @staticmethod
144
+ def _thousandth_standard_compare(input_data):
145
+ thousandth_compare = ThousandthStdCompare(input_data)
146
+ thousandth_compare.compare()
147
+
148
+ @staticmethod
149
+ def _absolute_standard_compare(input_data):
150
+ absolute_compare = AbsolutethdCompare(input_data)
151
+ absolute_compare.compare()
152
+
153
+ @staticmethod
154
+ def _ulp_compare(input_data):
155
+ ulp_compare = UlpCompare(input_data)
156
+ ulp_compare.compare()
157
+
158
+ @staticmethod
159
+ def _benchmark_compare(input_data):
160
+ benchmark_compare = BenchmarkCompare(input_data)
161
+ benchmark_compare.compare()
162
+
163
+ @staticmethod
164
+ def _accumulative_error_compare(input_data):
165
+ accumulative_error_compare = AccumulativeErrorCompare(input_data)
166
+ accumulative_error_compare.compare()
167
+
146
168
  def write_csv_title(self):
147
169
  summary_test_rows = [
148
170
  [self.COLUMN_API_NAME,
@@ -163,6 +185,8 @@ class Comparator:
163
185
  df_row = list(test_result[:INDEX_TEST_RESULT_GROUP])
164
186
  if test_result[1] == CompareConst.SKIP:
165
187
  df_row.append(test_result[INDEX_TEST_RESULT_GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE])
188
+ elif test_result[2] == CompareConst.SKIP:
189
+ df_row.append(test_result[BACKWARD_RESULT_GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE])
166
190
  if self.stack_info:
167
191
  stack_info = "\n".join(self.stack_info[name])
168
192
  df_row.append(stack_info)
@@ -211,6 +235,7 @@ class Comparator:
211
235
  if backward_message:
212
236
  backward_column = CompareColumn()
213
237
  bwd_compare_alg_results = [backward_column.to_column_value(CompareConst.SKIP, backward_message)]
238
+ bwd_success_status = CompareConst.SKIP
214
239
  else:
215
240
  bwd_success_status = bwd_success_status if bwd_compare_alg_results is not None else CompareConst.SPACE
216
241
  result_info = ResultInfo(full_api_name,
@@ -226,6 +251,16 @@ class Comparator:
226
251
  return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \
227
252
  or bwd_success_status == CompareConst.SPACE
228
253
 
254
+ def _register_compare_func(self):
255
+ registry = StandardRegistry()
256
+ registry.register(CompareConst.ABSOLUTE_THRESHOLD, self._absolute_standard_compare)
257
+ registry.register(CompareConst.BINARY_CONSISTENCY, self._binary_standard_compare)
258
+ registry.register(CompareConst.ULP_COMPARE, self._ulp_compare)
259
+ registry.register(CompareConst.THOUSANDTH_STANDARD, self._thousandth_standard_compare)
260
+ registry.register(CompareConst.BENCHMARK, self._benchmark_compare)
261
+ registry.register(CompareConst.ACCUMULATIVE_ERROR_COMPARE, self._accumulative_error_compare)
262
+ return registry
263
+
229
264
  def _compare_core_wrapper(self, api_name, bench_output, device_output):
230
265
  detailed_result_total = []
231
266
  test_final_success = CompareConst.PASS
@@ -308,11 +343,13 @@ class Comparator:
308
343
  return CompareConst.ERROR, compare_column, f"Bench out dtype is {bench_output.dtype} but " \
309
344
  f"npu output dtype is {device_output.dtype}, cannot compare."
310
345
  message = ""
346
+ if bench_output.size == 0:
347
+ return CompareConst.ERROR, compare_column, "There is not bench calculation result."
311
348
  if bench_output.dtype in [bool, np.uint8, np.int8, np.int16, np.uint16, np.uint32, np.int32,
312
349
  np.int64, np.uint64]:
313
350
  message += f"Compare algorithm is not supported for {bench_output.dtype} data. " \
314
351
  f"Only judged by Error Rate."
315
- err_rate, status, msg = self._compare_bool_tensor(bench_output, device_output)
352
+ err_rate, status, msg = compare_bool_tensor(bench_output, device_output)
316
353
  message += msg + "\n"
317
354
  compare_column.error_rate = err_rate
318
355
  return status, compare_column, message
@@ -321,56 +358,20 @@ class Comparator:
321
358
  compare_column, npu_dtype)
322
359
  return status, compare_column, message
323
360
 
361
+ def _perform_comparison(self, api_name, input_data):
362
+ comparison_func = self.registry.get_comparison_function(api_name, None)
363
+ comparison_func(input_data)
364
+
324
365
  def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype):
325
366
  message = ""
326
- abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
367
+ _, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
327
368
  abs_err = get_abs_err(bench_output, device_output)
328
369
  rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps)
329
- if api_name in thousandth_standard_api:
330
- thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD)
331
- compare_column.rel_err_thousandth = thousand_res
370
+ input_data = CompareInput(bench_output, device_output, compare_column, dtype, rel_err_orign)
332
371
  if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST:
333
- both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output)
334
- if api_name in binary_standard_api:
335
- err_rate, _, _ = self._compare_bool_tensor(bench_output, device_output)
336
- compare_column.error_rate = err_rate
337
- elif api_name in absolute_standard_api:
338
- small_value_threshold, small_value_atol, rtol = self._get_absolute_threshold_attribute(
339
- api_name, str(dtype))
340
- rel_err = abs_err / abs_bench_with_eps
341
- small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value_threshold)
342
- normal_value_mask = np.logical_and(both_finite_mask, np.logical_not(small_value_mask))
343
- compare_column.inf_nan_error_ratio = check_inf_nan_value(inf_nan_mask, bench_output, device_output,
344
- dtype, rtol)
345
- compare_column.rel_err_ratio = check_norm_value(normal_value_mask, rel_err, rtol)
346
- compare_column.abs_err_ratio = check_small_value(abs_err, small_value_mask, small_value_atol)
347
- elif api_name in ulp_standard_api:
348
- if bench_output.size == 0:
349
- compare_column.max_ulp_error = 0
350
- compare_column.mean_ulp_error = 0
351
- compare_column.ulp_error_proportion = 0
352
- else:
353
- ulp_err = get_ulp_err(bench_output, device_output, dtype)
354
- compare_column.max_ulp_error = np.max(ulp_err)
355
- compare_column.mean_ulp_error = np.mean(ulp_err)
356
- if dtype == torch.float32:
357
- compare_column.ulp_error_proportion = \
358
- np.sum(ulp_err > CompareConst.ULP_FLOAT32_THRESHOLD) / bench_output.size
359
- else:
360
- compare_column.ulp_error_proportion = \
361
- np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / bench_output.size
362
- else:
363
- dtype_config = precision_configs.get(dtype)
364
- small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, dtype_config['small_value'][0])
365
- abs_err_greater_mask = np.greater(abs_err, dtype_config['small_value_atol'][0])
366
- compare_column.small_value_err_ratio = get_small_value_err_ratio(small_value_mask, abs_err_greater_mask)
367
- rel_err = get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask)
368
- compare_column.rmse = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask))
369
- compare_column.eb = get_error_balance(bench_output, device_output)
370
- if rel_err.size == 0:
371
- return CompareConst.ERROR, compare_column, "Relative error result list is empty."
372
- compare_column.max_rel_error = get_max_rel_err(rel_err)
373
- compare_column.mean_rel_error = get_mean_rel_err(rel_err)
372
+ self._perform_comparison(api_name, input_data)
373
+ else:
374
+ message += f"The data type {dtype} is not supported for new precision standard."
374
375
 
375
376
  cos_res, cos_status, msg = cosine_sim(bench_output, device_output)
376
377
  compare_column.cosine_sim = cos_res
@@ -16,9 +16,17 @@
16
16
  # limitations under the License.
17
17
 
18
18
  from msprobe.core.common.const import CompareConst
19
+ from msprobe.pytorch.common.log import logger
19
20
 
20
21
 
21
22
  class CompareColumn:
23
+ __slots__ = [
24
+ 'bench_type', 'npu_type', 'shape', 'cosine_sim', 'max_abs_err', 'rel_err_hundredth',
25
+ 'rel_err_ten_thousandth', 'inf_nan_error_ratio', 'rel_err_ratio', 'abs_err_ratio',
26
+ 'small_value_err_ratio', 'max_rel_error', 'mean_rel_error', 'rmse', 'eb', 'max_ulp_error',
27
+ 'mean_ulp_error', 'ulp_error_proportion', 'error_rate', 'rel_err_thousandth'
28
+ ]
29
+
22
30
  def __init__(self):
23
31
  self.bench_type = CompareConst.SPACE
24
32
  self.npu_type = CompareConst.SPACE
@@ -41,6 +49,24 @@ class CompareColumn:
41
49
  self.mean_ulp_error = CompareConst.SPACE
42
50
  self.ulp_error_proportion = CompareConst.SPACE
43
51
 
52
+ def update(self, metrics):
53
+ """
54
+ Updates the object's attributes with the provided metrics.
55
+
56
+ Args:
57
+ metrics (dict): A dictionary containing attribute names and their corresponding values.
58
+
59
+ Raises:
60
+ AttributeError: If the metric key is not a valid attribute of CompareColumn.
61
+ """
62
+ for key, value in metrics.items():
63
+ if value is None:
64
+ continue
65
+ if key not in self.__slots__:
66
+ logger.error(f"The key '{key}' is not a valid attribute of CompareColumn.")
67
+ continue
68
+ setattr(self, key, value)
69
+
44
70
  def to_column_value(self, is_pass, message):
45
71
  return [self.bench_type, self.npu_type, self.shape, self.cosine_sim, self.max_abs_err, self.rel_err_hundredth,
46
72
  self.rel_err_thousandth, self.rel_err_ten_thousandth, self.error_rate, self.eb, self.rmse,
@@ -50,6 +76,16 @@ class CompareColumn:
50
76
 
51
77
 
52
78
  class ApiPrecisionOutputColumn:
79
+ __slots__ = [
80
+ 'api_name', 'small_value_err_ratio', 'small_value_err_status', 'rmse_ratio', 'rmse_status',
81
+ 'max_rel_err_ratio', 'max_rel_err_status', 'mean_rel_err_ratio', 'mean_rel_err_status', 'eb_ratio',
82
+ 'eb_status', 'inf_nan_error_ratio', 'inf_nan_error_ratio_status', 'rel_err_ratio',
83
+ 'rel_err_ratio_status', 'abs_err_ratio', 'abs_err_ratio_status', 'error_rate', 'error_rate_status',
84
+ 'mean_ulp_err', 'ulp_err_proportion', 'ulp_err_proportion_ratio', 'ulp_err_status',
85
+ 'rel_err_thousandth', 'rel_err_thousandth_status', 'compare_result', 'compare_algorithm',
86
+ 'compare_message'
87
+ ]
88
+
53
89
  def __init__(self):
54
90
  self.api_name = CompareConst.SPACE
55
91
  self.small_value_err_ratio = CompareConst.SPACE
@@ -80,6 +116,24 @@ class ApiPrecisionOutputColumn:
80
116
  self.compare_algorithm = CompareConst.SPACE
81
117
  self.compare_message = CompareConst.SPACE
82
118
 
119
+ def update(self, metrics):
120
+ """
121
+ Updates the object's attributes with the provided metrics.
122
+
123
+ Args:
124
+ metrics (dict): A dictionary containing attribute names and their corresponding values.
125
+
126
+ Raises:
127
+ AttributeError: If the metric key is not a valid attribute of CompareColumn.
128
+ """
129
+ for key, value in metrics.items():
130
+ if value is None:
131
+ continue
132
+ if key not in self.__slots__:
133
+ logger.error("The key '%s' is not a valid attribute of CompareColumn.", key)
134
+ continue
135
+ setattr(self, key, value)
136
+
83
137
  def to_column_value(self):
84
138
  return [self.api_name, self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio,
85
139
  self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio,
@@ -0,0 +1,51 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import numpy as np
19
+
20
+
21
+ class CompareInput:
22
+ """
23
+ A class to encapsulate the input data required for comparison operations.
24
+
25
+ Attributes:
26
+ bench_output (np.ndarray): The benchmark output values.
27
+ device_output (np.ndarray): The device output values.
28
+ compare_column (class): A clasee to store and update comparison metrics.
29
+ dtype (type, optional): The data type of the outputs. Defaults to None.
30
+ rel_err_orign (float or array-like, optional): The original relative error values. Defaults to None.
31
+
32
+ Methods:
33
+ __init__(bench_output, device_output, compare_column, dtype, rel_err_orign):
34
+ Initializes an instance of CompareInput.
35
+ """
36
+ def __init__(self, bench_output, device_output, compare_column, dtype=None, rel_err_orign=None):
37
+ self.bench_output = bench_output
38
+ self.device_output = device_output
39
+ if not isinstance(bench_output, np.ndarray) or not isinstance(device_output, np.ndarray):
40
+ raise TypeError("The input should be numpy array")
41
+ self.compare_column = compare_column
42
+ self.dtype = dtype
43
+ self.rel_err_orign = rel_err_orign
44
+
45
+
46
+ class PrecisionCompareInput:
47
+ def __init__(self, row_npu, row_gpu, dtype, compare_column):
48
+ self.row_npu = row_npu
49
+ self.row_gpu = row_gpu
50
+ self.dtype = dtype
51
+ self.compare_column = compare_column
@@ -43,10 +43,7 @@ absolute_standard_api = apis.get('AbsoluteThreshStandard')
43
43
  binary_standard_api = apis.get('BinaryCompareStandard')
44
44
  ulp_standard_api = apis.get('ULPStandard')
45
45
  thousandth_standard_api = apis.get('ThousandthStandard')
46
-
47
-
48
- threshold_yaml_path = os.path.join(cur_path, "api_precision_threshold.yaml")
49
- apis_threshold = load_yaml(threshold_yaml_path)
46
+ accumulative_error_standard_api = apis.get('AccumulativeErrorStandard')
50
47
 
51
48
 
52
49
  DETAIL_TEST_ROWS = [
@@ -134,6 +131,7 @@ ULP_PARAMETERS = {
134
131
  class ApiPrecisionCompareColumn:
135
132
  API_NAME = 'API Name'
136
133
  DEVICE_DTYPE = 'DEVICE Dtype'
134
+ SHAPE = 'Shape'
137
135
  SMALL_VALUE_ERROR_RATE = '小值域错误占比'
138
136
  RMSE = '均方根误差'
139
137
  MAX_REL_ERR = '相对误差最大值'
@@ -1,8 +1,9 @@
1
1
  #!/usr/bin/env python3
2
2
  # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
7
  # you may not use this file except in compliance with the License.
7
8
  # You may obtain a copy of the License at
8
9
  #
@@ -13,17 +14,18 @@
13
14
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
15
  # See the License for the specific language governing permissions and
15
16
  # limitations under the License.
16
- """
17
+
17
18
  import argparse
18
19
  import json
19
20
  import os
20
21
  import re
22
+
21
23
  import math
22
24
  import numpy as np
23
25
  import torch
24
26
 
25
-
26
- from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import binary_standard_api, absolute_standard_api, ulp_standard_api, thousandth_standard_api
27
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import binary_standard_api, absolute_standard_api, \
28
+ ulp_standard_api, thousandth_standard_api
27
29
  from msprobe.core.common.file_utils import FileOpen, load_json, save_json
28
30
  from msprobe.core.common.utils import check_file_or_directory_path, check_op_str_pattern_valid, is_int
29
31
  from msprobe.core.common.const import Const, MonitorConst, MsgConst
@@ -78,6 +80,7 @@ class APIInfo:
78
80
  def is_supported_type(self):
79
81
  return self.api_type in OPERATOR_TYPE
80
82
 
83
+
81
84
  class CommonConfig:
82
85
  def __init__(self, json_config):
83
86
  self.dump_json_path = json_config.get('dump_json_path')
@@ -147,6 +150,7 @@ class CommonConfig:
147
150
  if not is_int(self.iter_times):
148
151
  raise ValueError(f'iter_times is invalid, it should be an int')
149
152
 
153
+
150
154
  class APIExtractor:
151
155
  def __init__(self, api_name, dump_json_path, output_file):
152
156
  self.api_name = api_name
@@ -186,6 +190,7 @@ class APIExtractor:
186
190
  elif DATA_NAME in data:
187
191
  data[DATA_NAME] = os.path.join(dump_data_dir, data[DATA_NAME])
188
192
 
193
+
189
194
  class OperatorScriptGenerator:
190
195
  def __init__(self, common_config, args_info_forward, kwargs_info_forward, args_info_backward):
191
196
  self.common_config = common_config
@@ -238,7 +243,8 @@ class OperatorScriptGenerator:
238
243
  ordinal_number: how many times the same api has been called
239
244
  direction_status: forward
240
245
  random_seed: if mode is random_data, random seed is random_seed
241
- iter_times: if mode is random_data, generate iter_times group of data; if mode is real_data, iter_times does not matter
246
+ iter_times: if mode is random_data, generate iter_times group of data; if mode is real_data,
247
+ iter_times does not matter
242
248
  args_element_assignment: code for args assignment
243
249
  args_list_generator_device: code for generate args list on device
244
250
  args_list_generator_bench: code for generate args list on bench
@@ -267,17 +273,25 @@ class OperatorScriptGenerator:
267
273
  internal_settings["iter_times"] = 1
268
274
  else:
269
275
  internal_settings["iter_times"] = self.common_config.iter_times
270
- internal_settings["args_element_assignment"] = self.generate_args_element_assignment_code(self.args_info_forward)
271
- internal_settings["args_list_generator_device"] = self.generate_args_list(self.args_info_forward, flag_device=True)
272
- internal_settings["args_list_generator_bench"] = self.generate_args_list(self.args_info_forward, flag_device=False)
273
- internal_settings["kwargs_value_assignment"] = self.generate_kwargs_value_assignment_code(self.kwargs_info_forward)
274
- internal_settings["kwargs_dict_generator_device"] = self.generate_kwargs_dict(self.kwargs_info_forward, flag_device=True)
275
- internal_settings["kwargs_dict_generator_bench"] = self.generate_kwargs_dict(self.kwargs_info_forward, flag_device=False)
276
+ internal_settings["args_element_assignment"] = \
277
+ self.generate_args_element_assignment_code(self.args_info_forward)
278
+ internal_settings["args_list_generator_device"] = \
279
+ self.generate_args_list(self.args_info_forward, flag_device=True)
280
+ internal_settings["args_list_generator_bench"] = \
281
+ self.generate_args_list(self.args_info_forward, flag_device=False)
282
+ internal_settings["kwargs_value_assignment"] = \
283
+ self.generate_kwargs_value_assignment_code(self.kwargs_info_forward)
284
+ internal_settings["kwargs_dict_generator_device"] = \
285
+ self.generate_kwargs_dict(self.kwargs_info_forward, flag_device=True)
286
+ internal_settings["kwargs_dict_generator_bench"] = \
287
+ self.generate_kwargs_dict(self.kwargs_info_forward, flag_device=False)
276
288
  if self.common_config.propagation == Const.BACKWARD:
277
289
  internal_settings["args_element_assignment_backward"] = self.generate_args_element_assignment_code(
278
290
  self.args_info_backward)
279
- internal_settings["args_list_generator_device_backward"] = self.generate_args_list(self.args_info_backward, flag_device=True)
280
- internal_settings["args_list_generator_bench_backward"] = self.generate_args_list(self.args_info_backward, flag_device=False)
291
+ internal_settings["args_list_generator_device_backward"] = \
292
+ self.generate_args_list(self.args_info_backward, flag_device=True)
293
+ internal_settings["args_list_generator_bench_backward"] = \
294
+ self.generate_args_list(self.args_info_backward, flag_device=False)
281
295
  else:
282
296
  internal_settings["args_element_assignment_backward"] = ''
283
297
  internal_settings["args_list_generator_device_backward"] = ''
@@ -290,12 +304,15 @@ class OperatorScriptGenerator:
290
304
  args_element_assignment = ""
291
305
  for index, arg in enumerate(args_info):
292
306
  if isinstance(arg, (list, tuple)):
293
- new_args_element_assignment = self.recursive_args_element_assignment(arg, name_number + "_" + str(index))
307
+ new_args_element_assignment = \
308
+ self.recursive_args_element_assignment(arg, name_number + "_" + str(index))
294
309
  args_element_assignment += new_args_element_assignment
295
310
  else:
296
311
  arg["parameter_name"] = "arg" + name_number + "_" + str(index)
297
- args_element_assignment += " " + "arg_info" + name_number + "_" + str(index) + " = " + "{}".format(str(arg)) + MsgConst.SPECIAL_CHAR[0]
298
- args_element_assignment += " " + "arg" + name_number + "_" + str(index) + " = " + "generate_data(arg_info" + name_number + "_" + str(index) + ")" + MsgConst.SPECIAL_CHAR[0]
312
+ args_element_assignment += " " + "arg_info" + name_number + "_" + str(index) + " = " + \
313
+ "{}".format(str(arg)) + MsgConst.SPECIAL_CHAR[0]
314
+ args_element_assignment += " " + "arg" + name_number + "_" + str(index) + " = " + \
315
+ "generate_data(arg_info" + name_number + "_" + str(index) + ")" + MsgConst.SPECIAL_CHAR[0]
299
316
  return args_element_assignment
300
317
 
301
318
 
@@ -320,7 +337,8 @@ class OperatorScriptGenerator:
320
337
  args_list_generator += ".to(device)"
321
338
  if flag_bench:
322
339
  args_list_generator += '.to(torch.device("cpu"))'
323
- args_list_generator += ".to(RAISE_PRECISION.get(str(" + arg.get("parameter_name") + ".dtype), " + arg.get("parameter_name") + ".dtype))"
340
+ args_list_generator += ".to(RAISE_PRECISION.get(str(" + arg.get("parameter_name") + \
341
+ ".dtype), " + arg.get("parameter_name") + ".dtype))"
324
342
  args_list_generator += Const.COMMA
325
343
  return args_list_generator
326
344
 
@@ -338,12 +356,15 @@ class OperatorScriptGenerator:
338
356
  if info.get("type") == "torch.device" or info.get("type") == "torch.dtype":
339
357
  kwargs_value_assignment += " " + "kwarg_" + key_name + name_number + " = " + info.get("value")
340
358
  else:
341
- kwargs_value_assignment += " " + "kwarg_info_" + key_name + name_number + " = " + "{}".format(str(info)) + MsgConst.SPECIAL_CHAR[0]
342
- kwargs_value_assignment += " " + "kwarg_" + key_name + name_number + " = " + "generate_data(kwarg_info_" + key_name + name_number + ")" + MsgConst.SPECIAL_CHAR[0]
359
+ kwargs_value_assignment += " " + "kwarg_info_" + key_name + name_number + " = " + \
360
+ "{}".format(str(info)) + MsgConst.SPECIAL_CHAR[0]
361
+ kwargs_value_assignment += " " + "kwarg_" + key_name + name_number + " = " + \
362
+ "generate_data(kwarg_info_" + key_name + name_number + ")" + MsgConst.SPECIAL_CHAR[0]
343
363
  info["parameter_name"] = "kwarg_" + key_name + name_number
344
364
  else:
345
365
  for index, arg in enumerate(info):
346
- new_kwargs_value_assignment = self.recursive_kwargs_value_assignment(arg, key_name, name_number + "_" + str(index))
366
+ new_kwargs_value_assignment = self.recursive_kwargs_value_assignment(arg, key_name, name_number + \
367
+ "_" + str(index))
347
368
  kwargs_value_assignment += new_kwargs_value_assignment
348
369
  return kwargs_value_assignment
349
370
 
@@ -363,7 +384,8 @@ class OperatorScriptGenerator:
363
384
  kwargs_dict_generator += ".to(device)"
364
385
  if flag_bench:
365
386
  kwargs_dict_generator += '.to(torch.device("cpu"))'
366
- kwargs_dict_generator += ".to(RAISE_PRECISION.get(str(" + info.get("parameter_name") + ".dtype), " + info.get("parameter_name") + ".dtype))"
387
+ kwargs_dict_generator += ".to(RAISE_PRECISION.get(str(" + info.get("parameter_name") + \
388
+ ".dtype), " + info.get("parameter_name") + ".dtype))"
367
389
  else:
368
390
  (left_bracket, right_bracket) = ("[", "]") if isinstance(info, list) else ("(", ")")
369
391
  kwargs_dict_generator += left_bracket
@@ -386,13 +408,14 @@ class OperatorScriptGenerator:
386
408
 
387
409
 
388
410
 
389
- def op_generator_parser(parser):
411
+ def _op_generator_parser(parser):
390
412
  parser.add_argument("-i", "--config_input", dest="config_input", default='', type=str,
391
413
  help="<Optional> Path of config json file", required=True)
392
414
  parser.add_argument("-o", "--api_output_path", dest="api_output_path", type=str,
393
415
  help="<Required> Path of extract api_name.json.",
394
416
  required=True)
395
417
 
418
+
396
419
  def parse_json_config(json_file_path):
397
420
  if not json_file_path:
398
421
  config_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@@ -401,11 +424,8 @@ def parse_json_config(json_file_path):
401
424
  common_config = CommonConfig(json_config)
402
425
  return common_config
403
426
 
404
- def main():
405
- parser = argparse.ArgumentParser()
406
- op_generator_parser(parser)
407
- cmd_args = parser.parse_args()
408
427
 
428
+ def _run_operator_generate_commond(cmd_args):
409
429
  common_config = parse_json_config(cmd_args.config_input)
410
430
 
411
431
  if common_config.dump_json_path:
@@ -438,7 +458,8 @@ def main():
438
458
  internal_settings = op_generate.get_settings(api_full_name_forward)
439
459
 
440
460
  template_path = os.path.join(os.path.dirname(__file__), "operator_replication.template")
441
- operator_script_path = os.path.join(cmd_args.api_output_path, "{0}.py".format(internal_settings.get("api_full_name")))
461
+ operator_script_path = os.path.join(cmd_args.api_output_path,
462
+ "{0}.py".format(internal_settings.get("api_full_name")))
442
463
 
443
464
  try:
444
465
  with FileOpen(template_path, 'r') as ftemp, FileOpen(operator_script_path, 'w') as fout:
@@ -451,4 +472,7 @@ def main():
451
472
 
452
473
 
453
474
  if __name__ == "__main__":
454
- main()
475
+ parser = argparse.ArgumentParser()
476
+ _op_generator_parser(parser)
477
+ cmd_args = parser.parse_args()
478
+ _run_operator_generate_commond(cmd_args)