mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -40,7 +40,7 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validat
|
|
|
40
40
|
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments, extract_basic_api_segments
|
|
41
41
|
from msprobe.core.common.file_utils import FileChecker, change_mode, create_directory
|
|
42
42
|
from msprobe.pytorch.common.log import logger
|
|
43
|
-
from msprobe.core.common.utils import CompareException
|
|
43
|
+
from msprobe.core.common.utils import CompareException, check_op_str_pattern_valid
|
|
44
44
|
from msprobe.core.common.const import Const, CompareConst, FileCheckConst
|
|
45
45
|
|
|
46
46
|
CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
|
|
@@ -151,6 +151,7 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
151
151
|
message = ''
|
|
152
152
|
compare_column = ApiPrecisionOutputColumn()
|
|
153
153
|
full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
|
|
154
|
+
check_op_str_pattern_valid(full_api_name_with_direction_status)
|
|
154
155
|
row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status]
|
|
155
156
|
api_name, api_full_name, direction_status = extract_detailed_api_segments(full_api_name_with_direction_status)
|
|
156
157
|
if not api_full_name:
|
|
@@ -430,6 +431,7 @@ def _api_precision_compare(parser=None):
|
|
|
430
431
|
_api_precision_compare_parser(parser)
|
|
431
432
|
args = parser.parse_args(sys.argv[1:])
|
|
432
433
|
_api_precision_compare_command(args)
|
|
434
|
+
logger.info("Compare task completed.")
|
|
433
435
|
|
|
434
436
|
|
|
435
437
|
def _api_precision_compare_command(args):
|
|
@@ -457,8 +459,3 @@ def _api_precision_compare_parser(parser):
|
|
|
457
459
|
parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
|
|
458
460
|
help="<optional> The api precision compare task result out path.",
|
|
459
461
|
required=False)
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
if __name__ == '__main__':
|
|
463
|
-
_api_precision_compare()
|
|
464
|
-
logger.info("Compare task completed.")
|
|
@@ -40,6 +40,7 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dty
|
|
|
40
40
|
DETAIL_TEST_ROWS, BENCHMARK_COMPARE_SUPPORT_LIST
|
|
41
41
|
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
|
|
42
42
|
from msprobe.pytorch.common.log import logger
|
|
43
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
43
44
|
|
|
44
45
|
|
|
45
46
|
ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'bwd_success_status',
|
|
@@ -178,6 +179,41 @@ class Comparator:
|
|
|
178
179
|
if not os.path.exists(detail_save_path):
|
|
179
180
|
write_csv(DETAIL_TEST_ROWS, detail_save_path)
|
|
180
181
|
|
|
182
|
+
@recursion_depth_decorator("compare_core")
|
|
183
|
+
def _compare_core(self, api_name, bench_output, device_output):
|
|
184
|
+
compare_column = CompareColumn()
|
|
185
|
+
if not isinstance(bench_output, type(device_output)):
|
|
186
|
+
status = CompareConst.ERROR
|
|
187
|
+
message = "bench and npu output type is different."
|
|
188
|
+
elif isinstance(bench_output, dict):
|
|
189
|
+
b_keys, n_keys = set(bench_output.keys()), set(device_output.keys())
|
|
190
|
+
if b_keys != n_keys:
|
|
191
|
+
status = CompareConst.ERROR
|
|
192
|
+
message = "bench and npu output dict keys are different."
|
|
193
|
+
else:
|
|
194
|
+
status, compare_column, message = self._compare_core(api_name, list(bench_output.values()),
|
|
195
|
+
list(device_output.values()))
|
|
196
|
+
elif isinstance(bench_output, torch.Tensor):
|
|
197
|
+
copy_bench_out = bench_output.detach().clone()
|
|
198
|
+
copy_device_output = device_output.detach().clone()
|
|
199
|
+
compare_column.bench_type = str(copy_bench_out.dtype)
|
|
200
|
+
compare_column.npu_type = str(copy_device_output.dtype)
|
|
201
|
+
compare_column.shape = tuple(device_output.shape)
|
|
202
|
+
status, compare_column, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output,
|
|
203
|
+
compare_column)
|
|
204
|
+
elif isinstance(bench_output, (bool, int, float, str)):
|
|
205
|
+
compare_column.bench_type = str(type(bench_output))
|
|
206
|
+
compare_column.npu_type = str(type(device_output))
|
|
207
|
+
status, compare_column, message = self._compare_builtin_type(bench_output, device_output, compare_column)
|
|
208
|
+
elif bench_output is None:
|
|
209
|
+
status = CompareConst.SKIP
|
|
210
|
+
message = "Bench output is None, skip this test."
|
|
211
|
+
else:
|
|
212
|
+
status = CompareConst.ERROR
|
|
213
|
+
message = "Unexpected output type in compare_core: {}".format(type(bench_output))
|
|
214
|
+
|
|
215
|
+
return status, compare_column, message
|
|
216
|
+
|
|
181
217
|
def write_summary_csv(self, test_result):
|
|
182
218
|
test_rows = []
|
|
183
219
|
try:
|
|
@@ -293,40 +329,6 @@ class Comparator:
|
|
|
293
329
|
test_final_success = CompareConst.WARNING
|
|
294
330
|
return test_final_success, detailed_result_total
|
|
295
331
|
|
|
296
|
-
def _compare_core(self, api_name, bench_output, device_output):
|
|
297
|
-
compare_column = CompareColumn()
|
|
298
|
-
if not isinstance(bench_output, type(device_output)):
|
|
299
|
-
status = CompareConst.ERROR
|
|
300
|
-
message = "bench and npu output type is different."
|
|
301
|
-
elif isinstance(bench_output, dict):
|
|
302
|
-
b_keys, n_keys = set(bench_output.keys()), set(device_output.keys())
|
|
303
|
-
if b_keys != n_keys:
|
|
304
|
-
status = CompareConst.ERROR
|
|
305
|
-
message = "bench and npu output dict keys are different."
|
|
306
|
-
else:
|
|
307
|
-
status, compare_column, message = self._compare_core(api_name, list(bench_output.values()),
|
|
308
|
-
list(device_output.values()))
|
|
309
|
-
elif isinstance(bench_output, torch.Tensor):
|
|
310
|
-
copy_bench_out = bench_output.detach().clone()
|
|
311
|
-
copy_device_output = device_output.detach().clone()
|
|
312
|
-
compare_column.bench_type = str(copy_bench_out.dtype)
|
|
313
|
-
compare_column.npu_type = str(copy_device_output.dtype)
|
|
314
|
-
compare_column.shape = tuple(device_output.shape)
|
|
315
|
-
status, compare_column, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output,
|
|
316
|
-
compare_column)
|
|
317
|
-
elif isinstance(bench_output, (bool, int, float, str)):
|
|
318
|
-
compare_column.bench_type = str(type(bench_output))
|
|
319
|
-
compare_column.npu_type = str(type(device_output))
|
|
320
|
-
status, compare_column, message = self._compare_builtin_type(bench_output, device_output, compare_column)
|
|
321
|
-
elif bench_output is None:
|
|
322
|
-
status = CompareConst.SKIP
|
|
323
|
-
message = "Bench output is None, skip this test."
|
|
324
|
-
else:
|
|
325
|
-
status = CompareConst.ERROR
|
|
326
|
-
message = "Unexpected output type in compare_core: {}".format(type(bench_output))
|
|
327
|
-
|
|
328
|
-
return status, compare_column, message
|
|
329
|
-
|
|
330
332
|
def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column):
|
|
331
333
|
cpu_shape = bench_output.shape
|
|
332
334
|
npu_shape = device_output.shape
|
|
@@ -28,10 +28,10 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import binary_st
|
|
|
28
28
|
ulp_standard_api, thousandth_standard_api
|
|
29
29
|
from msprobe.core.common.file_utils import FileOpen, load_json, save_json
|
|
30
30
|
from msprobe.core.common.utils import check_file_or_directory_path, check_op_str_pattern_valid, is_int
|
|
31
|
-
from msprobe.core.common.const import Const, MonitorConst, MsgConst
|
|
31
|
+
from msprobe.core.common.const import Const, MonitorConst, MsgConst, FileCheckConst
|
|
32
32
|
from msprobe.core.common.log import logger
|
|
33
|
-
from msprobe.core.common.file_utils import make_dir
|
|
34
|
-
from msprobe.core.common.
|
|
33
|
+
from msprobe.core.common.file_utils import make_dir, change_mode
|
|
34
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
35
35
|
|
|
36
36
|
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
|
|
37
37
|
TORCH_BOOL_TYPE = ["torch.bool"]
|
|
@@ -50,6 +50,7 @@ DATA_NAME = "data_name"
|
|
|
50
50
|
API_MAX_LENGTH = 30
|
|
51
51
|
PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
|
|
52
52
|
DATAMODE_LIST = ["random_data", "real_data"]
|
|
53
|
+
ITER_MAX_TIMES = 1000
|
|
53
54
|
|
|
54
55
|
|
|
55
56
|
class APIInfo:
|
|
@@ -97,6 +98,8 @@ class CommonConfig:
|
|
|
97
98
|
iter_t = self.iter_times
|
|
98
99
|
if iter_t <= 0:
|
|
99
100
|
raise ValueError("iter_times should be an integer bigger than zero!")
|
|
101
|
+
if iter_t > ITER_MAX_TIMES:
|
|
102
|
+
raise ValueError("iter_times should not be greater than 1000!")
|
|
100
103
|
|
|
101
104
|
json_file = self.extract_api_path
|
|
102
105
|
propagation = self.propagation
|
|
@@ -117,7 +120,7 @@ class CommonConfig:
|
|
|
117
120
|
|
|
118
121
|
# Retrieve the first API name and dictionary
|
|
119
122
|
forward_item = next(iter(json_content.items()), None)
|
|
120
|
-
if not forward_item or not isinstance(forward_item[1], dict):
|
|
123
|
+
if not forward_item or not isinstance(forward_item[1], dict) or not forward_item[1]:
|
|
121
124
|
raise ValueError(f'Invalid forward API data in json_content!')
|
|
122
125
|
|
|
123
126
|
# if propagation is backward, ensure json file contains forward and backward info
|
|
@@ -127,7 +130,7 @@ class CommonConfig:
|
|
|
127
130
|
# if propagation is backward, ensure it has valid data
|
|
128
131
|
if propagation == Const.BACKWARD:
|
|
129
132
|
backward_item = list(json_content.items())[1]
|
|
130
|
-
if not isinstance(backward_item[1], dict):
|
|
133
|
+
if not isinstance(backward_item[1], dict) or not backward_item[1]:
|
|
131
134
|
raise ValueError(f'Invalid backward API data in json_content!')
|
|
132
135
|
|
|
133
136
|
return json_content
|
|
@@ -169,7 +172,7 @@ class APIExtractor:
|
|
|
169
172
|
value = self.load_real_data_path(value, real_data_path)
|
|
170
173
|
new_data[key] = value
|
|
171
174
|
if not new_data:
|
|
172
|
-
logger.
|
|
175
|
+
logger.warning(f"Warning: The api '{self.api_name}' does not exist in the file.")
|
|
173
176
|
else:
|
|
174
177
|
save_json(self.output_file, new_data, indent=4)
|
|
175
178
|
logger.info(
|
|
@@ -183,6 +186,7 @@ class APIExtractor:
|
|
|
183
186
|
self.update_data_name(v, dump_data_dir)
|
|
184
187
|
return value
|
|
185
188
|
|
|
189
|
+
@recursion_depth_decorator("OpGenerator: APIExtractor.update_data_name")
|
|
186
190
|
def update_data_name(self, data, dump_data_dir):
|
|
187
191
|
if isinstance(data, list):
|
|
188
192
|
for item in data:
|
|
@@ -407,19 +411,16 @@ class OperatorScriptGenerator:
|
|
|
407
411
|
return kwargs_dict_generator
|
|
408
412
|
|
|
409
413
|
|
|
410
|
-
|
|
411
414
|
def _op_generator_parser(parser):
|
|
412
|
-
parser.add_argument("-i", "--config_input", dest="config_input",
|
|
413
|
-
help="<
|
|
415
|
+
parser.add_argument("-i", "--config_input", dest="config_input", type=str,
|
|
416
|
+
help="<Required> Path of config json file", required=True)
|
|
414
417
|
parser.add_argument("-o", "--api_output_path", dest="api_output_path", type=str,
|
|
415
|
-
help="<Required> Path of extract api_name.json.",
|
|
416
|
-
required=True)
|
|
418
|
+
help="<Required> Path of extract api_name.json.", required=True)
|
|
417
419
|
|
|
418
420
|
|
|
419
421
|
def parse_json_config(json_file_path):
|
|
420
422
|
if not json_file_path:
|
|
421
|
-
|
|
422
|
-
json_file_path = os.path.join(config_dir, "config.json")
|
|
423
|
+
raise Exception("config_input path can not be empty, please check.")
|
|
423
424
|
json_config = load_json(json_file_path)
|
|
424
425
|
common_config = CommonConfig(json_config)
|
|
425
426
|
return common_config
|
|
@@ -467,6 +468,7 @@ def _run_operator_generate_commond(cmd_args):
|
|
|
467
468
|
fout.write(code_template.format(**internal_settings))
|
|
468
469
|
except OSError:
|
|
469
470
|
logger.error(f"Failed to open file. Please check file {template_path} or {operator_script_path}.")
|
|
471
|
+
change_mode(operator_script_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
470
472
|
|
|
471
473
|
logger.info(f"Generate operator script successfully and the name is {operator_script_path}.")
|
|
472
474
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import json
|
|
2
1
|
import os
|
|
3
|
-
import
|
|
2
|
+
import re
|
|
3
|
+
import stat
|
|
4
4
|
from enum import Enum, auto
|
|
5
5
|
import torch
|
|
6
6
|
try:
|
|
@@ -25,6 +25,31 @@ RAISE_PRECISION = {{
|
|
|
25
25
|
}}
|
|
26
26
|
THOUSANDTH_THRESHOLDING = 0.001
|
|
27
27
|
BACKWARD = 'backward'
|
|
28
|
+
DIR = "dir"
|
|
29
|
+
FILE = "file"
|
|
30
|
+
READ_ABLE = "read"
|
|
31
|
+
WRITE_ABLE = "write"
|
|
32
|
+
READ_WRITE_ABLE = "read and write"
|
|
33
|
+
DIRECTORY_LENGTH = 4096
|
|
34
|
+
FILE_NAME_LENGTH = 255
|
|
35
|
+
SOFT_LINK_ERROR = "检测到软链接"
|
|
36
|
+
FILE_PERMISSION_ERROR = "文件权限错误"
|
|
37
|
+
INVALID_FILE_ERROR = "无效文件"
|
|
38
|
+
ILLEGAL_PATH_ERROR = "非法文件路径"
|
|
39
|
+
ILLEGAL_PARAM_ERROR = "非法打开方式"
|
|
40
|
+
FILE_TOO_LARGE_ERROR = "文件过大"
|
|
41
|
+
FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
|
|
42
|
+
FILE_SIZE_DICT = {{
|
|
43
|
+
".pkl": 1073741824, # 1 * 1024 * 1024 * 1024
|
|
44
|
+
".npy": 10737418240, # 10 * 1024 * 1024 * 1024
|
|
45
|
+
".json": 1073741824, # 1 * 1024 * 1024 * 1024
|
|
46
|
+
".pt": 10737418240, # 10 * 1024 * 1024 * 1024
|
|
47
|
+
".csv": 1073741824, # 1 * 1024 * 1024 * 1024
|
|
48
|
+
".xlsx": 1073741824, # 1 * 1024 * 1024 * 1024
|
|
49
|
+
".yaml": 1073741824, # 1 * 1024 * 1024 * 1024
|
|
50
|
+
".ir": 1073741824 # 1 * 1024 * 1024 * 1024
|
|
51
|
+
}}
|
|
52
|
+
COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024
|
|
28
53
|
|
|
29
54
|
class CompareStandard(Enum):
|
|
30
55
|
BINARY_EQUALITY_STANDARD = auto()
|
|
@@ -33,13 +58,189 @@ class CompareStandard(Enum):
|
|
|
33
58
|
BENCHMARK_STANDARD = auto()
|
|
34
59
|
THOUSANDTH_STANDARD = auto()
|
|
35
60
|
|
|
61
|
+
class FileChecker:
|
|
62
|
+
"""
|
|
63
|
+
The class for checking files and directories.
|
|
64
|
+
|
|
65
|
+
Attributes:
|
|
66
|
+
file_path: The file or directory path to be verified.
|
|
67
|
+
path_type: file or directory
|
|
68
|
+
ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE, indicating whether the path must be writable or readable
|
|
69
|
+
file_type(str): The correct file type for file
|
|
70
|
+
"""
|
|
71
|
+
|
|
72
|
+
def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True):
|
|
73
|
+
self.file_path = file_path
|
|
74
|
+
self.path_type = self._check_path_type(path_type)
|
|
75
|
+
self.ability = ability
|
|
76
|
+
self.file_type = file_type
|
|
77
|
+
self.is_script = is_script
|
|
78
|
+
|
|
79
|
+
@staticmethod
|
|
80
|
+
def _check_path_type(path_type):
|
|
81
|
+
if path_type not in [DIR, FILE]:
|
|
82
|
+
print(f'ERROR: The path_type must be {{DIR}} or {{FILE}}.')
|
|
83
|
+
raise Exception(ILLEGAL_PARAM_ERROR)
|
|
84
|
+
return path_type
|
|
85
|
+
|
|
86
|
+
def common_check(self):
|
|
87
|
+
"""
|
|
88
|
+
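Function: verify basic file permissions: soft links, path length, existence, read/write permission, file owner, and special characters in the path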
|
|
89
|
+
Note: validating the file suffix is not a generic operation; it can be implemented through a separate, independent interface
|
|
90
|
+
"""
|
|
91
|
+
FileChecker.check_path_exists(self.file_path)
|
|
92
|
+
FileChecker.check_link(self.file_path)
|
|
93
|
+
self.file_path = os.path.realpath(self.file_path)
|
|
94
|
+
FileChecker.check_path_length(self.file_path)
|
|
95
|
+
FileChecker.check_path_type(self.file_path, self.path_type)
|
|
96
|
+
self.check_path_ability()
|
|
97
|
+
if self.is_script:
|
|
98
|
+
FileChecker.check_path_owner_consistent(self.file_path)
|
|
99
|
+
FileChecker.check_path_pattern_valid(self.file_path)
|
|
100
|
+
FileChecker.check_common_file_size(self.file_path)
|
|
101
|
+
FileChecker.check_file_suffix(self.file_path, self.file_type)
|
|
102
|
+
if self.path_type == FILE:
|
|
103
|
+
FileChecker.check_dirpath_before_read(self.file_path)
|
|
104
|
+
return self.file_path
|
|
105
|
+
|
|
106
|
+
def check_path_ability(self):
|
|
107
|
+
if self.ability == WRITE_ABLE:
|
|
108
|
+
FileChecker.check_path_writability(self.file_path)
|
|
109
|
+
if self.ability == READ_ABLE:
|
|
110
|
+
FileChecker.check_path_readability(self.file_path)
|
|
111
|
+
if self.ability == READ_WRITE_ABLE:
|
|
112
|
+
FileChecker.check_path_readability(self.file_path)
|
|
113
|
+
FileChecker.check_path_writability(self.file_path)
|
|
114
|
+
|
|
115
|
+
@staticmethod
|
|
116
|
+
def check_path_exists(path):
|
|
117
|
+
if not os.path.exists(path):
|
|
118
|
+
print(f'ERROR: The file path %s does not exist.' % path)
|
|
119
|
+
raise Exception()
|
|
120
|
+
|
|
121
|
+
@staticmethod
|
|
122
|
+
def check_link(path):
|
|
123
|
+
abs_path = os.path.abspath(path)
|
|
124
|
+
if os.path.islink(abs_path):
|
|
125
|
+
print('ERROR: The file path {{}} is a soft link.'.format(path))
|
|
126
|
+
raise Exception(SOFT_LINK_ERROR)
|
|
127
|
+
|
|
128
|
+
@staticmethod
|
|
129
|
+
def check_path_length(path, name_length=None):
    """
    Ensure neither the whole path nor its final component is too long.

    Args:
        path: the path string to measure.
        name_length: optional limit for the final component; a falsy value
            falls back to FILE_NAME_LENGTH.

    Raises:
        Exception: carrying ILLEGAL_PATH_ERROR when the whole path exceeds
            DIRECTORY_LENGTH or the final component exceeds its limit.
    """
    basename_limit = name_length or FILE_NAME_LENGTH
    whole_path_too_long = len(path) > DIRECTORY_LENGTH
    if whole_path_too_long or len(os.path.basename(path)) > basename_limit:
        print('ERROR: The file path length exceeds limit.')
        raise Exception(ILLEGAL_PATH_ERROR)
|
|
135
|
+
|
|
136
|
+
@staticmethod
|
|
137
|
+
def check_path_type(file_path, file_type):
    """
    Ensure that ``file_path`` is of the kind named by ``file_type``.

    Args:
        file_path: the path to inspect.
        file_type: FILE requires a regular file, DIR requires a directory;
            any other value performs no check.

    Raises:
        Exception: carrying INVALID_FILE_ERROR when the path is not of the
            requested kind.
    """
    # NOTE(review): the doubled braces in the messages are kept verbatim;
    # presumably this text is rendered through str.format - confirm.
    if file_type == FILE and not os.path.isfile(file_path):
        print(f"ERROR: The {{file_path}} should be a file!")
        raise Exception(INVALID_FILE_ERROR)
    if file_type == DIR and not os.path.isdir(file_path):
        # The message previously said "dictionary", which was a typo.
        print(f"ERROR: The {{file_path}} should be a directory!")
        raise Exception(INVALID_FILE_ERROR)
|
|
146
|
+
|
|
147
|
+
@staticmethod
|
|
148
|
+
def check_path_owner_consistent(path):
    """
    Ensure that ``path`` is owned by the current user.

    A process running with uid 0 passes regardless of the owner.

    Args:
        path: the path whose owner uid is read via ``os.stat``.

    Raises:
        Exception: carrying FILE_PERMISSION_ERROR when the owner differs
            from the current uid and the current uid is not 0.
        OSError: propagated from ``os.stat`` when the path cannot be read.
    """
    # The stat call stays first so an unreadable path fails for every user.
    file_owner = os.stat(path).st_uid
    current_uid = os.getuid()
    if current_uid != 0 and file_owner != current_uid:
        # The message previously read "because is does not", a typo.
        print('ERROR: The file path %s may be insecure because it does not belong to you.' % path)
        raise Exception(FILE_PERMISSION_ERROR)
|
|
153
|
+
|
|
154
|
+
@staticmethod
|
|
155
|
+
def check_path_pattern_valid(path):
    """
    Ensure that ``path`` matches FILE_VALID_PATTERN from its start.

    Raises:
        Exception: carrying ILLEGAL_PATH_ERROR when ``re.match`` finds no
            match.
    """
    if re.match(FILE_VALID_PATTERN, path):
        return
    print('ERROR: The file path %s contains special characters.' % path)
    raise Exception(ILLEGAL_PATH_ERROR)
|
|
159
|
+
|
|
160
|
+
@staticmethod
|
|
161
|
+
def check_common_file_size(file_path):
    """
    Apply the size limit that belongs to the suffix of ``file_path``.

    Paths that are not regular files are ignored. The first suffix in
    FILE_SIZE_DICT that the path ends with selects the limit; when none
    matches, COMMOM_FILE_SIZE is used.
    """
    if not os.path.isfile(file_path):
        return
    for known_suffix, suffix_limit in FILE_SIZE_DICT.items():
        if file_path.endswith(known_suffix):
            size_limit = suffix_limit
            break
    else:
        # No registered suffix matched.
        size_limit = COMMOM_FILE_SIZE
    FileChecker.check_file_size(file_path, size_limit)
|
|
168
|
+
|
|
169
|
+
@staticmethod
|
|
170
|
+
def check_file_size(file_path, max_size):
    """
    Ensure that the size of ``file_path`` is strictly below ``max_size``.

    Args:
        file_path: the file to measure with ``os.path.getsize``.
        max_size: the limit in bytes; a size equal to it is rejected too.

    Raises:
        Exception: carrying INVALID_FILE_ERROR (chained from the OSError)
            when the size cannot be read, or FILE_TOO_LARGE_ERROR when the
            size reaches the limit.
    """
    # NOTE(review): the doubled braces and the local names they mention are
    # kept verbatim; presumably this text is rendered through str.format
    # before it runs - confirm.
    try:
        file_size = os.path.getsize(file_path)
    except OSError as os_error:
        print(f'ERROR: Failed to open "{{file_path}}". {{str(os_error)}}')
        raise Exception(INVALID_FILE_ERROR) from os_error
    else:
        if file_size >= max_size:
            print(f'ERROR: The size ({{file_size}}) of {{file_path}} exceeds ({{max_size}}) bytes, tools not support.')
            raise Exception(FILE_TOO_LARGE_ERROR)
|
|
179
|
+
|
|
180
|
+
@staticmethod
|
|
181
|
+
def check_file_suffix(file_path, file_suffix):
    """
    Ensure that ``file_path`` ends with ``file_suffix``.

    A falsy ``file_suffix`` disables the check.

    Raises:
        Exception: carrying INVALID_FILE_ERROR when the suffix is required
            and the path does not end with it.
    """
    if not file_suffix or file_path.endswith(file_suffix):
        return
    # NOTE(review): doubled braces kept verbatim; presumably rendered
    # through str.format before it runs - confirm.
    print(f"The {{file_path}} should be a {{file_suffix}} file!")
    raise Exception(INVALID_FILE_ERROR)
|
|
186
|
+
|
|
187
|
+
@staticmethod
|
|
188
|
+
def check_dirpath_before_read(path):
    """
    Print warnings about the directory that contains ``path``.

    The path is resolved first. A warning is printed when the containing
    directory is reported writable by check_others_writable, and another
    when check_path_owner_consistent raises for it. Nothing is raised for
    either condition.
    """
    dirpath = os.path.dirname(os.path.realpath(path))
    loosely_writable = FileChecker.check_others_writable(dirpath)
    # NOTE(review): doubled braces and the name they mention are kept
    # verbatim; presumably rendered through str.format - confirm.
    if loosely_writable:
        print(f"WARNING: The directory is writable by others: {{dirpath}}.")
    try:
        FileChecker.check_path_owner_consistent(dirpath)
    except Exception:
        # Deliberately best-effort: any failure is downgraded to a warning.
        print(f"WARNING: The directory {{dirpath}} is not yours.")
|
|
197
|
+
|
|
198
|
+
@staticmethod
|
|
199
|
+
def check_others_writable(directory):
    """
    Report whether ``directory`` is writable by its group or by others.

    Returns:
        True when the group-write or the other-write permission bit is set
        in the mode returned by ``os.stat``, otherwise False.
    """
    mode_bits = os.stat(directory).st_mode
    # S_IWGRP: writable by the group; S_IWOTH: writable by other users.
    shared_write_mask = stat.S_IWGRP | stat.S_IWOTH
    return bool(mode_bits & shared_write_mask)
|
|
206
|
+
|
|
207
|
+
@staticmethod
|
|
208
|
+
def check_path_readability(path):
    """
    Ensure that the current process may read ``path``.

    Raises:
        Exception: carrying FILE_PERMISSION_ERROR when ``os.access`` denies
            read access.
    """
    if os.access(path, os.R_OK):
        return
    print('ERROR: The file path %s is not readable.' % path)
    raise Exception(FILE_PERMISSION_ERROR)
|
|
212
|
+
|
|
213
|
+
@staticmethod
|
|
214
|
+
def check_path_writability(path):
    """
    Ensure that the current process may write to ``path``.

    Raises:
        Exception: carrying FILE_PERMISSION_ERROR when ``os.access`` denies
            write access.
    """
    if os.access(path, os.W_OK):
        return
    print('ERROR: The file path %s is not writable.' % path)
    raise Exception(FILE_PERMISSION_ERROR)
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def check_file_or_directory_path(path, isdir=False):
    """
    Validate ``path`` with FileChecker.

    Parameter:
        path: the path to check
        isdir: when truthy the path is checked as a writable directory,
            otherwise as a readable file

    Exception Description:
        any exception raised by FileChecker.common_check propagates
    """
    expected_kind, required_access = (DIR, WRITE_ABLE) if isdir else (FILE, READ_ABLE)
    FileChecker(path, expected_kind, required_access).common_check()
|
|
235
|
+
|
|
36
236
|
def load_pt(pt_path, to_cpu=False):
    """
    Load a tensor file with ``torch.load`` in weights-only mode.

    Args:
        pt_path: path of the file; it is resolved with ``os.path.realpath``
            and validated by check_file_or_directory_path before loading.
        to_cpu: when truthy the data is mapped onto the CPU device.

    Returns:
        The object returned by ``torch.load``.

    Raises:
        RuntimeError: chained from any exception raised while loading.
    """
    pt_path = os.path.realpath(pt_path)
    check_file_or_directory_path(pt_path)
    try:
        # Keyword form of a dict literal is used on purpose here.
        load_options = dict(weights_only=True)
        if to_cpu:
            load_options["map_location"] = torch.device("cpu")
        pt = torch.load(pt_path, **load_options)
    except Exception as load_error:
        # NOTE(review): doubled braces kept verbatim; presumably rendered
        # through str.format before it runs - confirm.
        raise RuntimeError(f"load pt file {{pt_path}} failed") from load_error
    return pt
|
|
@@ -202,6 +403,7 @@ def compare_tensor(out_device, out_bench, api_name):
|
|
|
202
403
|
else:
|
|
203
404
|
abs_err = torch.abs(out_device - out_bench)
|
|
204
405
|
abs_bench = torch.abs(out_bench)
|
|
406
|
+
eps = 2 ** -23
|
|
205
407
|
if dtype_bench == torch.float32:
|
|
206
408
|
eps = 2 ** -23
|
|
207
409
|
if dtype_bench == torch.float64:
|
|
@@ -50,6 +50,9 @@ def split_json_file(input_file, num_splits, filter_api):
|
|
|
50
50
|
backward_data[f"{data_name}.backward"] = backward_data.pop(data_name)
|
|
51
51
|
|
|
52
52
|
input_data = load_json(input_file)
|
|
53
|
+
if "dump_data_dir" not in input_data.keys():
|
|
54
|
+
logger.error("Invalid input file, 'dump_data_dir' field is missing")
|
|
55
|
+
raise CompareException("Invalid input file, 'dump_data_dir' field is missing")
|
|
53
56
|
if input_data.get("data") is None:
|
|
54
57
|
logger.error("Invalid input file, 'data' field is missing")
|
|
55
58
|
raise CompareException("Invalid input file, 'data' field is missing")
|
|
@@ -84,10 +87,6 @@ def signal_handler(signum, frame):
|
|
|
84
87
|
raise KeyboardInterrupt()
|
|
85
88
|
|
|
86
89
|
|
|
87
|
-
signal.signal(signal.SIGINT, signal_handler)
|
|
88
|
-
signal.signal(signal.SIGTERM, signal_handler)
|
|
89
|
-
|
|
90
|
-
|
|
91
90
|
ParallelUTConfig = namedtuple('ParallelUTConfig', ['api_files', 'out_path', 'num_splits',
|
|
92
91
|
'save_error_data_flag', 'jit_compile_flag', 'device_id',
|
|
93
92
|
'result_csv_path', 'total_items', 'config_path'])
|
|
@@ -97,7 +96,7 @@ def run_parallel_ut(config):
|
|
|
97
96
|
processes = []
|
|
98
97
|
device_id_cycle = cycle(config.device_id)
|
|
99
98
|
if config.save_error_data_flag:
|
|
100
|
-
logger.info("UT task error
|
|
99
|
+
logger.info("UT task error data will be saved")
|
|
101
100
|
logger.info(f"Starting parallel UT with {config.num_splits} processes")
|
|
102
101
|
progress_bar = tqdm(total=config.total_items, desc="Total items", unit="items")
|
|
103
102
|
|
|
@@ -129,6 +128,9 @@ def run_parallel_ut(config):
|
|
|
129
128
|
sys.stdout.flush()
|
|
130
129
|
except ValueError as e:
|
|
131
130
|
logger.warning(f"An error occurred while reading subprocess output: {e}")
|
|
131
|
+
finally:
|
|
132
|
+
if process.poll() is None:
|
|
133
|
+
process.stdout.close()
|
|
132
134
|
|
|
133
135
|
def update_progress_bar(progress_bar, result_csv_path):
|
|
134
136
|
while any(process.poll() is None for process in processes):
|
|
@@ -214,6 +216,8 @@ def prepare_config(args):
|
|
|
214
216
|
|
|
215
217
|
|
|
216
218
|
def main():
|
|
219
|
+
signal.signal(signal.SIGINT, signal_handler)
|
|
220
|
+
signal.signal(signal.SIGTERM, signal_handler)
|
|
217
221
|
parser = argparse.ArgumentParser(description='Run UT in parallel')
|
|
218
222
|
_run_ut_parser(parser)
|
|
219
223
|
parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
|
|
@@ -221,7 +225,3 @@ def main():
|
|
|
221
225
|
args = parser.parse_args()
|
|
222
226
|
config = prepare_config(args)
|
|
223
227
|
run_parallel_ut(config)
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
if __name__ == '__main__':
|
|
227
|
-
main()
|
|
@@ -34,8 +34,10 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, i
|
|
|
34
34
|
from msprobe.core.common.file_utils import check_link, FileChecker
|
|
35
35
|
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
|
|
36
36
|
from msprobe.core.common.const import FileCheckConst, Const
|
|
37
|
+
from msprobe.core.common.utils import check_op_str_pattern_valid
|
|
37
38
|
from msprobe.pytorch.common.log import logger
|
|
38
39
|
from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
|
|
40
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
39
41
|
|
|
40
42
|
|
|
41
43
|
def check_tensor_overflow(x):
|
|
@@ -63,6 +65,7 @@ def check_tensor_overflow(x):
|
|
|
63
65
|
return False
|
|
64
66
|
|
|
65
67
|
|
|
68
|
+
@recursion_depth_decorator("check_data_overflow")
|
|
66
69
|
def check_data_overflow(x, device):
|
|
67
70
|
if isinstance(x, (tuple, list)):
|
|
68
71
|
if not x:
|
|
@@ -75,6 +78,7 @@ def check_data_overflow(x, device):
|
|
|
75
78
|
return torch_npu.npu.utils.npu_check_overflow(x)
|
|
76
79
|
|
|
77
80
|
|
|
81
|
+
@recursion_depth_decorator("is_bool_output")
|
|
78
82
|
def is_bool_output(x):
|
|
79
83
|
if isinstance(x, (tuple, list)):
|
|
80
84
|
if not x:
|
|
@@ -91,6 +95,7 @@ def run_overflow_check(forward_file):
|
|
|
91
95
|
dump_path = os.path.dirname(forward_file)
|
|
92
96
|
real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA)
|
|
93
97
|
for api_full_name, api_info_dict in tqdm(forward_content.items()):
|
|
98
|
+
check_op_str_pattern_valid(api_full_name)
|
|
94
99
|
if is_unsupported_api(api_full_name, is_overflow_check=True):
|
|
95
100
|
continue
|
|
96
101
|
try:
|
|
@@ -161,6 +166,7 @@ def _run_overflow_check(parser=None):
|
|
|
161
166
|
_run_overflow_check_parser(parser)
|
|
162
167
|
args = parser.parse_args(sys.argv[1:])
|
|
163
168
|
_run_overflow_check_command(args)
|
|
169
|
+
logger.info("UT task completed.")
|
|
164
170
|
|
|
165
171
|
|
|
166
172
|
def _run_overflow_check_command(args):
|
|
@@ -175,8 +181,3 @@ def _run_overflow_check_command(args):
|
|
|
175
181
|
logger.error(f"Set NPU device id failed. device id is: {args.device_id}")
|
|
176
182
|
raise NotImplementedError from error
|
|
177
183
|
run_overflow_check(api_info)
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
if __name__ == '__main__':
|
|
181
|
-
_run_overflow_check()
|
|
182
|
-
logger.info("UT task completed.")
|