mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -24,15 +24,20 @@ from msprobe.core.common.utils import CompareException
|
|
|
24
24
|
from msprobe.core.common.file_utils import get_json_contents, write_csv
|
|
25
25
|
import torch
|
|
26
26
|
from msprobe.core.common.const import CompareConst
|
|
27
|
-
from msprobe.pytorch.api_accuracy_checker.
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
27
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_register import StandardRegistry
|
|
28
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.absolute_threshold import AbsolutethdCompare
|
|
29
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.benchmark_compare import BenchmarkCompare
|
|
30
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.ulp_compare import UlpCompare
|
|
31
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.binary_consistency import BinaryCompare
|
|
32
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.thousandth_standard import ThousandthStdCompare
|
|
33
|
+
from msprobe.pytorch.api_accuracy_checker.precision_standard.accumulative_error_compare import AccumulativeErrorCompare
|
|
34
|
+
from msprobe.pytorch.api_accuracy_checker.compare.compare_input import CompareInput
|
|
35
|
+
from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_err, get_max_abs_err, get_rel_err_ratio, \
|
|
36
|
+
cosine_sim, get_rel_err_origin, get_abs_bench_with_eps, compare_bool_tensor
|
|
31
37
|
from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
|
|
32
38
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
|
|
33
39
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
|
|
34
|
-
DETAIL_TEST_ROWS,
|
|
35
|
-
ulp_standard_api, thousandth_standard_api, apis_threshold
|
|
40
|
+
DETAIL_TEST_ROWS, BENCHMARK_COMPARE_SUPPORT_LIST
|
|
36
41
|
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
|
|
37
42
|
from msprobe.pytorch.common.log import logger
|
|
38
43
|
|
|
@@ -42,6 +47,7 @@ ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'b
|
|
|
42
47
|
|
|
43
48
|
|
|
44
49
|
INDEX_TEST_RESULT_GROUP = 3
|
|
50
|
+
BACKWARD_RESULT_GROUP = 4
|
|
45
51
|
INDEX_FIRST_GROUP = 0
|
|
46
52
|
INDEX_MESSAGE = -1
|
|
47
53
|
|
|
@@ -66,6 +72,8 @@ class Comparator:
|
|
|
66
72
|
self.detail_save_path_list = \
|
|
67
73
|
[self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list]
|
|
68
74
|
|
|
75
|
+
self.registry = self._register_compare_func()
|
|
76
|
+
|
|
69
77
|
if not is_continue_run_ut:
|
|
70
78
|
self.write_csv_title()
|
|
71
79
|
if stack_info_json_path:
|
|
@@ -101,22 +109,6 @@ class Comparator:
|
|
|
101
109
|
compare_column.error_rate = 0
|
|
102
110
|
return CompareConst.PASS, compare_column, ""
|
|
103
111
|
|
|
104
|
-
@staticmethod
|
|
105
|
-
def _compare_bool_tensor(bench_output, device_output):
|
|
106
|
-
error_nums = (bench_output != device_output).sum()
|
|
107
|
-
if bench_output.size == 0:
|
|
108
|
-
return CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result."
|
|
109
|
-
error_rate = float(error_nums / bench_output.size)
|
|
110
|
-
result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR
|
|
111
|
-
return error_rate, result, ""
|
|
112
|
-
|
|
113
|
-
@staticmethod
|
|
114
|
-
def _get_absolute_threshold_attribute(api_name, dtype):
|
|
115
|
-
small_value_threshold = apis_threshold.get(api_name).get(dtype).get('small_value')
|
|
116
|
-
small_value_atol = apis_threshold.get(api_name).get(dtype).get('small_value_atol')
|
|
117
|
-
rtol = apis_threshold.get(api_name).get(dtype).get('rtol')
|
|
118
|
-
return small_value_threshold, small_value_atol, rtol
|
|
119
|
-
|
|
120
112
|
@staticmethod
|
|
121
113
|
def _get_run_ut_detail(test_result):
|
|
122
114
|
"""get run_ut detail before write to csv, called by online run_ut"""
|
|
@@ -143,6 +135,36 @@ class Comparator:
|
|
|
143
135
|
test_rows.append([subject] + list(test_subject))
|
|
144
136
|
return test_rows
|
|
145
137
|
|
|
138
|
+
@staticmethod
|
|
139
|
+
def _binary_standard_compare(input_data):
|
|
140
|
+
binary_compare = BinaryCompare(input_data)
|
|
141
|
+
binary_compare.compare()
|
|
142
|
+
|
|
143
|
+
@staticmethod
|
|
144
|
+
def _thousandth_standard_compare(input_data):
|
|
145
|
+
thousandth_compare = ThousandthStdCompare(input_data)
|
|
146
|
+
thousandth_compare.compare()
|
|
147
|
+
|
|
148
|
+
@staticmethod
|
|
149
|
+
def _absolute_standard_compare(input_data):
|
|
150
|
+
absolute_compare = AbsolutethdCompare(input_data)
|
|
151
|
+
absolute_compare.compare()
|
|
152
|
+
|
|
153
|
+
@staticmethod
|
|
154
|
+
def _ulp_compare(input_data):
|
|
155
|
+
ulp_compare = UlpCompare(input_data)
|
|
156
|
+
ulp_compare.compare()
|
|
157
|
+
|
|
158
|
+
@staticmethod
|
|
159
|
+
def _benchmark_compare(input_data):
|
|
160
|
+
benchmark_compare = BenchmarkCompare(input_data)
|
|
161
|
+
benchmark_compare.compare()
|
|
162
|
+
|
|
163
|
+
@staticmethod
|
|
164
|
+
def _accumulative_error_compare(input_data):
|
|
165
|
+
accumulative_error_compare = AccumulativeErrorCompare(input_data)
|
|
166
|
+
accumulative_error_compare.compare()
|
|
167
|
+
|
|
146
168
|
def write_csv_title(self):
|
|
147
169
|
summary_test_rows = [
|
|
148
170
|
[self.COLUMN_API_NAME,
|
|
@@ -163,6 +185,8 @@ class Comparator:
|
|
|
163
185
|
df_row = list(test_result[:INDEX_TEST_RESULT_GROUP])
|
|
164
186
|
if test_result[1] == CompareConst.SKIP:
|
|
165
187
|
df_row.append(test_result[INDEX_TEST_RESULT_GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE])
|
|
188
|
+
elif test_result[2] == CompareConst.SKIP:
|
|
189
|
+
df_row.append(test_result[BACKWARD_RESULT_GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE])
|
|
166
190
|
if self.stack_info:
|
|
167
191
|
stack_info = "\n".join(self.stack_info[name])
|
|
168
192
|
df_row.append(stack_info)
|
|
@@ -211,6 +235,7 @@ class Comparator:
|
|
|
211
235
|
if backward_message:
|
|
212
236
|
backward_column = CompareColumn()
|
|
213
237
|
bwd_compare_alg_results = [backward_column.to_column_value(CompareConst.SKIP, backward_message)]
|
|
238
|
+
bwd_success_status = CompareConst.SKIP
|
|
214
239
|
else:
|
|
215
240
|
bwd_success_status = bwd_success_status if bwd_compare_alg_results is not None else CompareConst.SPACE
|
|
216
241
|
result_info = ResultInfo(full_api_name,
|
|
@@ -226,6 +251,16 @@ class Comparator:
|
|
|
226
251
|
return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \
|
|
227
252
|
or bwd_success_status == CompareConst.SPACE
|
|
228
253
|
|
|
254
|
+
def _register_compare_func(self):
|
|
255
|
+
registry = StandardRegistry()
|
|
256
|
+
registry.register(CompareConst.ABSOLUTE_THRESHOLD, self._absolute_standard_compare)
|
|
257
|
+
registry.register(CompareConst.BINARY_CONSISTENCY, self._binary_standard_compare)
|
|
258
|
+
registry.register(CompareConst.ULP_COMPARE, self._ulp_compare)
|
|
259
|
+
registry.register(CompareConst.THOUSANDTH_STANDARD, self._thousandth_standard_compare)
|
|
260
|
+
registry.register(CompareConst.BENCHMARK, self._benchmark_compare)
|
|
261
|
+
registry.register(CompareConst.ACCUMULATIVE_ERROR_COMPARE, self._accumulative_error_compare)
|
|
262
|
+
return registry
|
|
263
|
+
|
|
229
264
|
def _compare_core_wrapper(self, api_name, bench_output, device_output):
|
|
230
265
|
detailed_result_total = []
|
|
231
266
|
test_final_success = CompareConst.PASS
|
|
@@ -308,11 +343,13 @@ class Comparator:
|
|
|
308
343
|
return CompareConst.ERROR, compare_column, f"Bench out dtype is {bench_output.dtype} but " \
|
|
309
344
|
f"npu output dtype is {device_output.dtype}, cannot compare."
|
|
310
345
|
message = ""
|
|
346
|
+
if bench_output.size == 0:
|
|
347
|
+
return CompareConst.ERROR, compare_column, "There is not bench calculation result."
|
|
311
348
|
if bench_output.dtype in [bool, np.uint8, np.int8, np.int16, np.uint16, np.uint32, np.int32,
|
|
312
349
|
np.int64, np.uint64]:
|
|
313
350
|
message += f"Compare algorithm is not supported for {bench_output.dtype} data. " \
|
|
314
351
|
f"Only judged by Error Rate."
|
|
315
|
-
err_rate, status, msg =
|
|
352
|
+
err_rate, status, msg = compare_bool_tensor(bench_output, device_output)
|
|
316
353
|
message += msg + "\n"
|
|
317
354
|
compare_column.error_rate = err_rate
|
|
318
355
|
return status, compare_column, message
|
|
@@ -321,56 +358,20 @@ class Comparator:
|
|
|
321
358
|
compare_column, npu_dtype)
|
|
322
359
|
return status, compare_column, message
|
|
323
360
|
|
|
361
|
+
def _perform_comparison(self, api_name, input_data):
|
|
362
|
+
comparison_func = self.registry.get_comparison_function(api_name, None)
|
|
363
|
+
comparison_func(input_data)
|
|
364
|
+
|
|
324
365
|
def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype):
|
|
325
366
|
message = ""
|
|
326
|
-
|
|
367
|
+
_, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
|
|
327
368
|
abs_err = get_abs_err(bench_output, device_output)
|
|
328
369
|
rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps)
|
|
329
|
-
|
|
330
|
-
thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD)
|
|
331
|
-
compare_column.rel_err_thousandth = thousand_res
|
|
370
|
+
input_data = CompareInput(bench_output, device_output, compare_column, dtype, rel_err_orign)
|
|
332
371
|
if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST:
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
compare_column.error_rate = err_rate
|
|
337
|
-
elif api_name in absolute_standard_api:
|
|
338
|
-
small_value_threshold, small_value_atol, rtol = self._get_absolute_threshold_attribute(
|
|
339
|
-
api_name, str(dtype))
|
|
340
|
-
rel_err = abs_err / abs_bench_with_eps
|
|
341
|
-
small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value_threshold)
|
|
342
|
-
normal_value_mask = np.logical_and(both_finite_mask, np.logical_not(small_value_mask))
|
|
343
|
-
compare_column.inf_nan_error_ratio = check_inf_nan_value(inf_nan_mask, bench_output, device_output,
|
|
344
|
-
dtype, rtol)
|
|
345
|
-
compare_column.rel_err_ratio = check_norm_value(normal_value_mask, rel_err, rtol)
|
|
346
|
-
compare_column.abs_err_ratio = check_small_value(abs_err, small_value_mask, small_value_atol)
|
|
347
|
-
elif api_name in ulp_standard_api:
|
|
348
|
-
if bench_output.size == 0:
|
|
349
|
-
compare_column.max_ulp_error = 0
|
|
350
|
-
compare_column.mean_ulp_error = 0
|
|
351
|
-
compare_column.ulp_error_proportion = 0
|
|
352
|
-
else:
|
|
353
|
-
ulp_err = get_ulp_err(bench_output, device_output, dtype)
|
|
354
|
-
compare_column.max_ulp_error = np.max(ulp_err)
|
|
355
|
-
compare_column.mean_ulp_error = np.mean(ulp_err)
|
|
356
|
-
if dtype == torch.float32:
|
|
357
|
-
compare_column.ulp_error_proportion = \
|
|
358
|
-
np.sum(ulp_err > CompareConst.ULP_FLOAT32_THRESHOLD) / bench_output.size
|
|
359
|
-
else:
|
|
360
|
-
compare_column.ulp_error_proportion = \
|
|
361
|
-
np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / bench_output.size
|
|
362
|
-
else:
|
|
363
|
-
dtype_config = precision_configs.get(dtype)
|
|
364
|
-
small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, dtype_config['small_value'][0])
|
|
365
|
-
abs_err_greater_mask = np.greater(abs_err, dtype_config['small_value_atol'][0])
|
|
366
|
-
compare_column.small_value_err_ratio = get_small_value_err_ratio(small_value_mask, abs_err_greater_mask)
|
|
367
|
-
rel_err = get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask)
|
|
368
|
-
compare_column.rmse = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask))
|
|
369
|
-
compare_column.eb = get_error_balance(bench_output, device_output)
|
|
370
|
-
if rel_err.size == 0:
|
|
371
|
-
return CompareConst.ERROR, compare_column, "Relative error result list is empty."
|
|
372
|
-
compare_column.max_rel_error = get_max_rel_err(rel_err)
|
|
373
|
-
compare_column.mean_rel_error = get_mean_rel_err(rel_err)
|
|
372
|
+
self._perform_comparison(api_name, input_data)
|
|
373
|
+
else:
|
|
374
|
+
message += f"The data type {dtype} is not supported for new precision standard."
|
|
374
375
|
|
|
375
376
|
cos_res, cos_status, msg = cosine_sim(bench_output, device_output)
|
|
376
377
|
compare_column.cosine_sim = cos_res
|
|
@@ -16,9 +16,17 @@
|
|
|
16
16
|
# limitations under the License.
|
|
17
17
|
|
|
18
18
|
from msprobe.core.common.const import CompareConst
|
|
19
|
+
from msprobe.pytorch.common.log import logger
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
class CompareColumn:
|
|
23
|
+
__slots__ = [
|
|
24
|
+
'bench_type', 'npu_type', 'shape', 'cosine_sim', 'max_abs_err', 'rel_err_hundredth',
|
|
25
|
+
'rel_err_ten_thousandth', 'inf_nan_error_ratio', 'rel_err_ratio', 'abs_err_ratio',
|
|
26
|
+
'small_value_err_ratio', 'max_rel_error', 'mean_rel_error', 'rmse', 'eb', 'max_ulp_error',
|
|
27
|
+
'mean_ulp_error', 'ulp_error_proportion', 'error_rate', 'rel_err_thousandth'
|
|
28
|
+
]
|
|
29
|
+
|
|
22
30
|
def __init__(self):
|
|
23
31
|
self.bench_type = CompareConst.SPACE
|
|
24
32
|
self.npu_type = CompareConst.SPACE
|
|
@@ -41,6 +49,24 @@ class CompareColumn:
|
|
|
41
49
|
self.mean_ulp_error = CompareConst.SPACE
|
|
42
50
|
self.ulp_error_proportion = CompareConst.SPACE
|
|
43
51
|
|
|
52
|
+
def update(self, metrics):
|
|
53
|
+
"""
|
|
54
|
+
Updates the object's attributes with the provided metrics.
|
|
55
|
+
|
|
56
|
+
Args:
|
|
57
|
+
metrics (dict): A dictionary containing attribute names and their corresponding values.
|
|
58
|
+
|
|
59
|
+
Raises:
|
|
60
|
+
AttributeError: If the metric key is not a valid attribute of CompareColumn.
|
|
61
|
+
"""
|
|
62
|
+
for key, value in metrics.items():
|
|
63
|
+
if value is None:
|
|
64
|
+
continue
|
|
65
|
+
if key not in self.__slots__:
|
|
66
|
+
logger.error(f"The key '{key}' is not a valid attribute of CompareColumn.")
|
|
67
|
+
continue
|
|
68
|
+
setattr(self, key, value)
|
|
69
|
+
|
|
44
70
|
def to_column_value(self, is_pass, message):
|
|
45
71
|
return [self.bench_type, self.npu_type, self.shape, self.cosine_sim, self.max_abs_err, self.rel_err_hundredth,
|
|
46
72
|
self.rel_err_thousandth, self.rel_err_ten_thousandth, self.error_rate, self.eb, self.rmse,
|
|
@@ -50,6 +76,16 @@ class CompareColumn:
|
|
|
50
76
|
|
|
51
77
|
|
|
52
78
|
class ApiPrecisionOutputColumn:
|
|
79
|
+
__slots__ = [
|
|
80
|
+
'api_name', 'small_value_err_ratio', 'small_value_err_status', 'rmse_ratio', 'rmse_status',
|
|
81
|
+
'max_rel_err_ratio', 'max_rel_err_status', 'mean_rel_err_ratio', 'mean_rel_err_status', 'eb_ratio',
|
|
82
|
+
'eb_status', 'inf_nan_error_ratio', 'inf_nan_error_ratio_status', 'rel_err_ratio',
|
|
83
|
+
'rel_err_ratio_status', 'abs_err_ratio', 'abs_err_ratio_status', 'error_rate', 'error_rate_status',
|
|
84
|
+
'mean_ulp_err', 'ulp_err_proportion', 'ulp_err_proportion_ratio', 'ulp_err_status',
|
|
85
|
+
'rel_err_thousandth', 'rel_err_thousandth_status', 'compare_result', 'compare_algorithm',
|
|
86
|
+
'compare_message'
|
|
87
|
+
]
|
|
88
|
+
|
|
53
89
|
def __init__(self):
|
|
54
90
|
self.api_name = CompareConst.SPACE
|
|
55
91
|
self.small_value_err_ratio = CompareConst.SPACE
|
|
@@ -80,6 +116,24 @@ class ApiPrecisionOutputColumn:
|
|
|
80
116
|
self.compare_algorithm = CompareConst.SPACE
|
|
81
117
|
self.compare_message = CompareConst.SPACE
|
|
82
118
|
|
|
119
|
+
def update(self, metrics):
|
|
120
|
+
"""
|
|
121
|
+
Updates the object's attributes with the provided metrics.
|
|
122
|
+
|
|
123
|
+
Args:
|
|
124
|
+
metrics (dict): A dictionary containing attribute names and their corresponding values.
|
|
125
|
+
|
|
126
|
+
Raises:
|
|
127
|
+
AttributeError: If the metric key is not a valid attribute of CompareColumn.
|
|
128
|
+
"""
|
|
129
|
+
for key, value in metrics.items():
|
|
130
|
+
if value is None:
|
|
131
|
+
continue
|
|
132
|
+
if key not in self.__slots__:
|
|
133
|
+
logger.error("The key '%s' is not a valid attribute of CompareColumn.", key)
|
|
134
|
+
continue
|
|
135
|
+
setattr(self, key, value)
|
|
136
|
+
|
|
83
137
|
def to_column_value(self):
|
|
84
138
|
return [self.api_name, self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio,
|
|
85
139
|
self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio,
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class CompareInput:
|
|
22
|
+
"""
|
|
23
|
+
A class to encapsulate the input data required for comparison operations.
|
|
24
|
+
|
|
25
|
+
Attributes:
|
|
26
|
+
bench_output (np.ndarray): The benchmark output values.
|
|
27
|
+
device_output (np.ndarray): The device output values.
|
|
28
|
+
compare_column (class): A clasee to store and update comparison metrics.
|
|
29
|
+
dtype (type, optional): The data type of the outputs. Defaults to None.
|
|
30
|
+
rel_err_orign (float or array-like, optional): The original relative error values. Defaults to None.
|
|
31
|
+
|
|
32
|
+
Methods:
|
|
33
|
+
__init__(bench_output, device_output, compare_column, dtype, rel_err_orign):
|
|
34
|
+
Initializes an instance of CompareInput.
|
|
35
|
+
"""
|
|
36
|
+
def __init__(self, bench_output, device_output, compare_column, dtype=None, rel_err_orign=None):
|
|
37
|
+
self.bench_output = bench_output
|
|
38
|
+
self.device_output = device_output
|
|
39
|
+
if not isinstance(bench_output, np.ndarray) or not isinstance(device_output, np.ndarray):
|
|
40
|
+
raise TypeError("The input should be numpy array")
|
|
41
|
+
self.compare_column = compare_column
|
|
42
|
+
self.dtype = dtype
|
|
43
|
+
self.rel_err_orign = rel_err_orign
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class PrecisionCompareInput:
|
|
47
|
+
def __init__(self, row_npu, row_gpu, dtype, compare_column):
|
|
48
|
+
self.row_npu = row_npu
|
|
49
|
+
self.row_gpu = row_gpu
|
|
50
|
+
self.dtype = dtype
|
|
51
|
+
self.compare_column = compare_column
|
|
@@ -43,10 +43,7 @@ absolute_standard_api = apis.get('AbsoluteThreshStandard')
|
|
|
43
43
|
binary_standard_api = apis.get('BinaryCompareStandard')
|
|
44
44
|
ulp_standard_api = apis.get('ULPStandard')
|
|
45
45
|
thousandth_standard_api = apis.get('ThousandthStandard')
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
threshold_yaml_path = os.path.join(cur_path, "api_precision_threshold.yaml")
|
|
49
|
-
apis_threshold = load_yaml(threshold_yaml_path)
|
|
46
|
+
accumulative_error_standard_api = apis.get('AccumulativeErrorStandard')
|
|
50
47
|
|
|
51
48
|
|
|
52
49
|
DETAIL_TEST_ROWS = [
|
|
@@ -134,6 +131,7 @@ ULP_PARAMETERS = {
|
|
|
134
131
|
class ApiPrecisionCompareColumn:
|
|
135
132
|
API_NAME = 'API Name'
|
|
136
133
|
DEVICE_DTYPE = 'DEVICE Dtype'
|
|
134
|
+
SHAPE = 'Shape'
|
|
137
135
|
SMALL_VALUE_ERROR_RATE = '小值域错误占比'
|
|
138
136
|
RMSE = '均方根误差'
|
|
139
137
|
MAX_REL_ERR = '相对误差最大值'
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
2
|
# -*- coding: utf-8 -*-
|
|
3
|
-
|
|
4
|
-
#
|
|
5
|
-
#
|
|
3
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
7
|
# you may not use this file except in compliance with the License.
|
|
7
8
|
# You may obtain a copy of the License at
|
|
8
9
|
#
|
|
@@ -13,17 +14,18 @@
|
|
|
13
14
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
15
|
# See the License for the specific language governing permissions and
|
|
15
16
|
# limitations under the License.
|
|
16
|
-
|
|
17
|
+
|
|
17
18
|
import argparse
|
|
18
19
|
import json
|
|
19
20
|
import os
|
|
20
21
|
import re
|
|
22
|
+
|
|
21
23
|
import math
|
|
22
24
|
import numpy as np
|
|
23
25
|
import torch
|
|
24
26
|
|
|
25
|
-
|
|
26
|
-
|
|
27
|
+
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import binary_standard_api, absolute_standard_api, \
|
|
28
|
+
ulp_standard_api, thousandth_standard_api
|
|
27
29
|
from msprobe.core.common.file_utils import FileOpen, load_json, save_json
|
|
28
30
|
from msprobe.core.common.utils import check_file_or_directory_path, check_op_str_pattern_valid, is_int
|
|
29
31
|
from msprobe.core.common.const import Const, MonitorConst, MsgConst
|
|
@@ -78,6 +80,7 @@ class APIInfo:
|
|
|
78
80
|
def is_supported_type(self):
|
|
79
81
|
return self.api_type in OPERATOR_TYPE
|
|
80
82
|
|
|
83
|
+
|
|
81
84
|
class CommonConfig:
|
|
82
85
|
def __init__(self, json_config):
|
|
83
86
|
self.dump_json_path = json_config.get('dump_json_path')
|
|
@@ -147,6 +150,7 @@ class CommonConfig:
|
|
|
147
150
|
if not is_int(self.iter_times):
|
|
148
151
|
raise ValueError(f'iter_times is invalid, it should be an int')
|
|
149
152
|
|
|
153
|
+
|
|
150
154
|
class APIExtractor:
|
|
151
155
|
def __init__(self, api_name, dump_json_path, output_file):
|
|
152
156
|
self.api_name = api_name
|
|
@@ -186,6 +190,7 @@ class APIExtractor:
|
|
|
186
190
|
elif DATA_NAME in data:
|
|
187
191
|
data[DATA_NAME] = os.path.join(dump_data_dir, data[DATA_NAME])
|
|
188
192
|
|
|
193
|
+
|
|
189
194
|
class OperatorScriptGenerator:
|
|
190
195
|
def __init__(self, common_config, args_info_forward, kwargs_info_forward, args_info_backward):
|
|
191
196
|
self.common_config = common_config
|
|
@@ -238,7 +243,8 @@ class OperatorScriptGenerator:
|
|
|
238
243
|
ordinal_number: how many times the same api has been called
|
|
239
244
|
direction_status: forward
|
|
240
245
|
random_seed: if mode is random_data, random seed is random_seed
|
|
241
|
-
iter_times: if mode is random_data, generate iter_times group of data; if mode is real_data,
|
|
246
|
+
iter_times: if mode is random_data, generate iter_times group of data; if mode is real_data,
|
|
247
|
+
iter_times does not matter
|
|
242
248
|
args_element_assignment: code for args assignment
|
|
243
249
|
args_list_generator_device: code for generate args list on device
|
|
244
250
|
args_list_generator_bench: code for generate args list on bench
|
|
@@ -267,17 +273,25 @@ class OperatorScriptGenerator:
|
|
|
267
273
|
internal_settings["iter_times"] = 1
|
|
268
274
|
else:
|
|
269
275
|
internal_settings["iter_times"] = self.common_config.iter_times
|
|
270
|
-
internal_settings["args_element_assignment"] =
|
|
271
|
-
|
|
272
|
-
internal_settings["
|
|
273
|
-
|
|
274
|
-
internal_settings["
|
|
275
|
-
|
|
276
|
+
internal_settings["args_element_assignment"] = \
|
|
277
|
+
self.generate_args_element_assignment_code(self.args_info_forward)
|
|
278
|
+
internal_settings["args_list_generator_device"] = \
|
|
279
|
+
self.generate_args_list(self.args_info_forward, flag_device=True)
|
|
280
|
+
internal_settings["args_list_generator_bench"] = \
|
|
281
|
+
self.generate_args_list(self.args_info_forward, flag_device=False)
|
|
282
|
+
internal_settings["kwargs_value_assignment"] = \
|
|
283
|
+
self.generate_kwargs_value_assignment_code(self.kwargs_info_forward)
|
|
284
|
+
internal_settings["kwargs_dict_generator_device"] = \
|
|
285
|
+
self.generate_kwargs_dict(self.kwargs_info_forward, flag_device=True)
|
|
286
|
+
internal_settings["kwargs_dict_generator_bench"] = \
|
|
287
|
+
self.generate_kwargs_dict(self.kwargs_info_forward, flag_device=False)
|
|
276
288
|
if self.common_config.propagation == Const.BACKWARD:
|
|
277
289
|
internal_settings["args_element_assignment_backward"] = self.generate_args_element_assignment_code(
|
|
278
290
|
self.args_info_backward)
|
|
279
|
-
internal_settings["args_list_generator_device_backward"] =
|
|
280
|
-
|
|
291
|
+
internal_settings["args_list_generator_device_backward"] = \
|
|
292
|
+
self.generate_args_list(self.args_info_backward, flag_device=True)
|
|
293
|
+
internal_settings["args_list_generator_bench_backward"] = \
|
|
294
|
+
self.generate_args_list(self.args_info_backward, flag_device=False)
|
|
281
295
|
else:
|
|
282
296
|
internal_settings["args_element_assignment_backward"] = ''
|
|
283
297
|
internal_settings["args_list_generator_device_backward"] = ''
|
|
@@ -290,12 +304,15 @@ class OperatorScriptGenerator:
|
|
|
290
304
|
args_element_assignment = ""
|
|
291
305
|
for index, arg in enumerate(args_info):
|
|
292
306
|
if isinstance(arg, (list, tuple)):
|
|
293
|
-
new_args_element_assignment =
|
|
307
|
+
new_args_element_assignment = \
|
|
308
|
+
self.recursive_args_element_assignment(arg, name_number + "_" + str(index))
|
|
294
309
|
args_element_assignment += new_args_element_assignment
|
|
295
310
|
else:
|
|
296
311
|
arg["parameter_name"] = "arg" + name_number + "_" + str(index)
|
|
297
|
-
args_element_assignment += " " + "arg_info" + name_number + "_" + str(index) + " = " +
|
|
298
|
-
|
|
312
|
+
args_element_assignment += " " + "arg_info" + name_number + "_" + str(index) + " = " + \
|
|
313
|
+
"{}".format(str(arg)) + MsgConst.SPECIAL_CHAR[0]
|
|
314
|
+
args_element_assignment += " " + "arg" + name_number + "_" + str(index) + " = " + \
|
|
315
|
+
"generate_data(arg_info" + name_number + "_" + str(index) + ")" + MsgConst.SPECIAL_CHAR[0]
|
|
299
316
|
return args_element_assignment
|
|
300
317
|
|
|
301
318
|
|
|
@@ -320,7 +337,8 @@ class OperatorScriptGenerator:
|
|
|
320
337
|
args_list_generator += ".to(device)"
|
|
321
338
|
if flag_bench:
|
|
322
339
|
args_list_generator += '.to(torch.device("cpu"))'
|
|
323
|
-
args_list_generator += ".to(RAISE_PRECISION.get(str(" + arg.get("parameter_name") +
|
|
340
|
+
args_list_generator += ".to(RAISE_PRECISION.get(str(" + arg.get("parameter_name") + \
|
|
341
|
+
".dtype), " + arg.get("parameter_name") + ".dtype))"
|
|
324
342
|
args_list_generator += Const.COMMA
|
|
325
343
|
return args_list_generator
|
|
326
344
|
|
|
@@ -338,12 +356,15 @@ class OperatorScriptGenerator:
|
|
|
338
356
|
if info.get("type") == "torch.device" or info.get("type") == "torch.dtype":
|
|
339
357
|
kwargs_value_assignment += " " + "kwarg_" + key_name + name_number + " = " + info.get("value")
|
|
340
358
|
else:
|
|
341
|
-
kwargs_value_assignment += " " + "kwarg_info_" + key_name + name_number + " = " +
|
|
342
|
-
|
|
359
|
+
kwargs_value_assignment += " " + "kwarg_info_" + key_name + name_number + " = " + \
|
|
360
|
+
"{}".format(str(info)) + MsgConst.SPECIAL_CHAR[0]
|
|
361
|
+
kwargs_value_assignment += " " + "kwarg_" + key_name + name_number + " = " + \
|
|
362
|
+
"generate_data(kwarg_info_" + key_name + name_number + ")" + MsgConst.SPECIAL_CHAR[0]
|
|
343
363
|
info["parameter_name"] = "kwarg_" + key_name + name_number
|
|
344
364
|
else:
|
|
345
365
|
for index, arg in enumerate(info):
|
|
346
|
-
new_kwargs_value_assignment = self.recursive_kwargs_value_assignment(arg, key_name, name_number +
|
|
366
|
+
new_kwargs_value_assignment = self.recursive_kwargs_value_assignment(arg, key_name, name_number + \
|
|
367
|
+
"_" + str(index))
|
|
347
368
|
kwargs_value_assignment += new_kwargs_value_assignment
|
|
348
369
|
return kwargs_value_assignment
|
|
349
370
|
|
|
@@ -363,7 +384,8 @@ class OperatorScriptGenerator:
|
|
|
363
384
|
kwargs_dict_generator += ".to(device)"
|
|
364
385
|
if flag_bench:
|
|
365
386
|
kwargs_dict_generator += '.to(torch.device("cpu"))'
|
|
366
|
-
kwargs_dict_generator += ".to(RAISE_PRECISION.get(str(" + info.get("parameter_name") +
|
|
387
|
+
kwargs_dict_generator += ".to(RAISE_PRECISION.get(str(" + info.get("parameter_name") + \
|
|
388
|
+
".dtype), " + info.get("parameter_name") + ".dtype))"
|
|
367
389
|
else:
|
|
368
390
|
(left_bracket, right_bracket) = ("[", "]") if isinstance(info, list) else ("(", ")")
|
|
369
391
|
kwargs_dict_generator += left_bracket
|
|
@@ -386,13 +408,14 @@ class OperatorScriptGenerator:
|
|
|
386
408
|
|
|
387
409
|
|
|
388
410
|
|
|
389
|
-
def
|
|
411
|
+
def _op_generator_parser(parser):
|
|
390
412
|
parser.add_argument("-i", "--config_input", dest="config_input", default='', type=str,
|
|
391
413
|
help="<Optional> Path of config json file", required=True)
|
|
392
414
|
parser.add_argument("-o", "--api_output_path", dest="api_output_path", type=str,
|
|
393
415
|
help="<Required> Path of extract api_name.json.",
|
|
394
416
|
required=True)
|
|
395
417
|
|
|
418
|
+
|
|
396
419
|
def parse_json_config(json_file_path):
|
|
397
420
|
if not json_file_path:
|
|
398
421
|
config_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
@@ -401,11 +424,8 @@ def parse_json_config(json_file_path):
|
|
|
401
424
|
common_config = CommonConfig(json_config)
|
|
402
425
|
return common_config
|
|
403
426
|
|
|
404
|
-
def main():
|
|
405
|
-
parser = argparse.ArgumentParser()
|
|
406
|
-
op_generator_parser(parser)
|
|
407
|
-
cmd_args = parser.parse_args()
|
|
408
427
|
|
|
428
|
+
def _run_operator_generate_commond(cmd_args):
|
|
409
429
|
common_config = parse_json_config(cmd_args.config_input)
|
|
410
430
|
|
|
411
431
|
if common_config.dump_json_path:
|
|
@@ -438,7 +458,8 @@ def main():
|
|
|
438
458
|
internal_settings = op_generate.get_settings(api_full_name_forward)
|
|
439
459
|
|
|
440
460
|
template_path = os.path.join(os.path.dirname(__file__), "operator_replication.template")
|
|
441
|
-
operator_script_path = os.path.join(cmd_args.api_output_path,
|
|
461
|
+
operator_script_path = os.path.join(cmd_args.api_output_path,
|
|
462
|
+
"{0}.py".format(internal_settings.get("api_full_name")))
|
|
442
463
|
|
|
443
464
|
try:
|
|
444
465
|
with FileOpen(template_path, 'r') as ftemp, FileOpen(operator_script_path, 'w') as fout:
|
|
@@ -451,4 +472,7 @@ def main():
|
|
|
451
472
|
|
|
452
473
|
|
|
453
474
|
if __name__ == "__main__":
|
|
454
|
-
|
|
475
|
+
parser = argparse.ArgumentParser()
|
|
476
|
+
_op_generator_parser(parser)
|
|
477
|
+
cmd_args = parser.parse_args()
|
|
478
|
+
_run_operator_generate_commond(cmd_args)
|