mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -1,3 +1,20 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
1
18
|
import subprocess
|
|
2
19
|
import json
|
|
3
20
|
import os
|
|
@@ -16,9 +33,10 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
|
|
|
16
33
|
from msprobe.pytorch.common import parse_json_info_forward_backward
|
|
17
34
|
from msprobe.pytorch.common.log import logger
|
|
18
35
|
from msprobe.core.common.file_utils import FileChecker, check_file_suffix, check_link, FileOpen, \
|
|
19
|
-
|
|
36
|
+
create_directory, load_json, save_json
|
|
20
37
|
from msprobe.core.common.file_utils import remove_path
|
|
21
|
-
from msprobe.core.common.const import FileCheckConst
|
|
38
|
+
from msprobe.core.common.const import FileCheckConst, Const
|
|
39
|
+
from msprobe.core.common.utils import CompareException
|
|
22
40
|
|
|
23
41
|
|
|
24
42
|
def split_json_file(input_file, num_splits, filter_api):
|
|
@@ -30,9 +48,11 @@ def split_json_file(input_file, num_splits, filter_api):
|
|
|
30
48
|
for data_name in list(backward_data.keys()):
|
|
31
49
|
backward_data[f"{data_name}.backward"] = backward_data.pop(data_name)
|
|
32
50
|
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
51
|
+
input_data = load_json(input_file)
|
|
52
|
+
if input_data.get("data") is None:
|
|
53
|
+
logger.error("Invalid input file, 'data' field is missing")
|
|
54
|
+
raise CompareException("Invalid input file, 'data' field is missing")
|
|
55
|
+
input_data.pop("data")
|
|
36
56
|
|
|
37
57
|
items = list(forward_data.items())
|
|
38
58
|
total_items = len(items)
|
|
@@ -52,8 +72,7 @@ def split_json_file(input_file, num_splits, filter_api):
|
|
|
52
72
|
}
|
|
53
73
|
}
|
|
54
74
|
split_filename = f"temp_part{i}.json"
|
|
55
|
-
|
|
56
|
-
json.dump(temp_data, split_file)
|
|
75
|
+
save_json(split_filename, temp_data)
|
|
57
76
|
split_files.append(split_filename)
|
|
58
77
|
|
|
59
78
|
return split_files, total_items
|
|
@@ -105,7 +124,7 @@ def run_parallel_ut(config):
|
|
|
105
124
|
if output == '':
|
|
106
125
|
break
|
|
107
126
|
if '[ERROR]' in output:
|
|
108
|
-
|
|
127
|
+
logger.warning(output)
|
|
109
128
|
sys.stdout.flush()
|
|
110
129
|
except ValueError as e:
|
|
111
130
|
logger.warning(f"An error occurred while reading subprocess output: {e}")
|
|
@@ -119,7 +138,8 @@ def run_parallel_ut(config):
|
|
|
119
138
|
|
|
120
139
|
for api_info in config.api_files:
|
|
121
140
|
cmd = create_cmd(api_info, next(device_id_cycle))
|
|
122
|
-
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
|
|
141
|
+
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
|
|
142
|
+
text=True, bufsize=1, shell=False)
|
|
123
143
|
processes.append(process)
|
|
124
144
|
threading.Thread(target=read_process_output, args=(process,), daemon=True).start()
|
|
125
145
|
|
|
@@ -150,7 +170,8 @@ def run_parallel_ut(config):
|
|
|
150
170
|
logger.error(f"An unexpected error occurred: {e}")
|
|
151
171
|
finally:
|
|
152
172
|
if progress_bar.n < config.total_items:
|
|
153
|
-
logger.warning("The UT task has not been completed. The parameter '-csv_path' along with the path to
|
|
173
|
+
logger.warning("The UT task has not been completed. The parameter '-csv_path' along with the path to " \
|
|
174
|
+
"the result CSV file will be utilized to resume the UT task.")
|
|
154
175
|
clean_up()
|
|
155
176
|
progress_bar_thread.join()
|
|
156
177
|
try:
|
|
@@ -163,17 +184,21 @@ def run_parallel_ut(config):
|
|
|
163
184
|
|
|
164
185
|
|
|
165
186
|
def prepare_config(args):
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
out_path =
|
|
170
|
-
check_path_before_create(out_path)
|
|
187
|
+
api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
|
|
188
|
+
ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
|
|
189
|
+
api_info = api_info_file_checker.common_check()
|
|
190
|
+
out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
|
|
171
191
|
create_directory(out_path)
|
|
172
192
|
out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
|
|
173
193
|
out_path = out_path_checker.common_check()
|
|
174
194
|
split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
|
|
175
|
-
config_path =
|
|
176
|
-
|
|
195
|
+
config_path = args.config_path if args.config_path else None
|
|
196
|
+
if config_path:
|
|
197
|
+
config_path_checker = FileChecker(config_path, FileCheckConst.FILE,
|
|
198
|
+
FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
|
|
199
|
+
config_path = config_path_checker.common_check()
|
|
200
|
+
result_csv_path = args.result_csv_path or os.path.join(
|
|
201
|
+
out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
|
|
177
202
|
if not args.result_csv_path:
|
|
178
203
|
details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv")
|
|
179
204
|
comparator = Comparator(result_csv_path, details_csv_path, False)
|
|
@@ -190,7 +215,8 @@ def prepare_config(args):
|
|
|
190
215
|
def main():
|
|
191
216
|
parser = argparse.ArgumentParser(description='Run UT in parallel')
|
|
192
217
|
_run_ut_parser(parser)
|
|
193
|
-
parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
|
|
218
|
+
parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
|
|
219
|
+
help='Number of splits for parallel processing. Range: 1-64')
|
|
194
220
|
args = parser.parse_args()
|
|
195
221
|
config = prepare_config(args)
|
|
196
222
|
run_parallel_ut(config)
|
|
@@ -1,3 +1,20 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
1
18
|
import argparse
|
|
2
19
|
import os
|
|
3
20
|
import sys
|
|
@@ -11,11 +28,12 @@ else:
|
|
|
11
28
|
import torch
|
|
12
29
|
from tqdm import tqdm
|
|
13
30
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import generate_device_params, get_api_info
|
|
14
|
-
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api
|
|
15
|
-
from msprobe.core.common.file_utils import check_link
|
|
31
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, is_unsupported_api
|
|
32
|
+
from msprobe.core.common.file_utils import check_link, FileChecker
|
|
33
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
|
|
34
|
+
from msprobe.core.common.const import FileCheckConst, Const
|
|
16
35
|
from msprobe.pytorch.common.log import logger
|
|
17
36
|
from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
|
|
18
|
-
from msprobe.core.common.const import Const
|
|
19
37
|
|
|
20
38
|
|
|
21
39
|
def check_tensor_overflow(x):
|
|
@@ -24,8 +42,8 @@ def check_tensor_overflow(x):
|
|
|
24
42
|
tensor_max = x.cpu().detach().float().numpy().tolist()
|
|
25
43
|
tensor_min = tensor_max
|
|
26
44
|
else:
|
|
27
|
-
tensor_max = torch.
|
|
28
|
-
tensor_min = torch.
|
|
45
|
+
tensor_max = torch.max(x).cpu().detach().float().numpy().tolist()
|
|
46
|
+
tensor_min = torch.min(x).cpu().detach().float().numpy().tolist()
|
|
29
47
|
# inf
|
|
30
48
|
if tensor_max == float('inf') or tensor_min == float('-inf'):
|
|
31
49
|
return True
|
|
@@ -57,23 +75,25 @@ def run_overflow_check(forward_file):
|
|
|
57
75
|
logger.info("start UT test")
|
|
58
76
|
forward_content, _, real_data_path = parse_json_info_forward_backward(forward_file)
|
|
59
77
|
for api_full_name, api_info_dict in tqdm(forward_content.items()):
|
|
78
|
+
if is_unsupported_api(api_full_name, is_overflow_check=True):
|
|
79
|
+
continue
|
|
60
80
|
try:
|
|
61
81
|
run_torch_api(api_full_name, api_info_dict, real_data_path)
|
|
62
82
|
except Exception as err:
|
|
63
83
|
_, api_name, _ = api_full_name.split(Const.SEP)
|
|
64
84
|
if "not implemented for 'Half'" in str(err):
|
|
65
|
-
logger.warning(f"API {api_name} not support half tensor in CPU
|
|
66
|
-
|
|
85
|
+
logger.warning(f"API {api_name} not support half tensor in CPU. This API does not support overflow "
|
|
86
|
+
"check, so it will be skipped.")
|
|
67
87
|
elif "expected scalar type Long" in str(err):
|
|
68
88
|
logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
|
|
69
|
-
|
|
89
|
+
"'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
|
|
70
90
|
else:
|
|
71
91
|
logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
|
|
72
92
|
|
|
73
93
|
|
|
74
94
|
def run_torch_api(api_full_name, api_info_dict, real_data_path):
|
|
75
95
|
torch.npu.clear_npu_overflow_flag()
|
|
76
|
-
api_type, api_name
|
|
96
|
+
api_type, api_name = extract_basic_api_segments(api_full_name)
|
|
77
97
|
args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
|
|
78
98
|
if not need_grad:
|
|
79
99
|
logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
|
|
@@ -118,8 +138,9 @@ def _run_overflow_check(parser=None):
|
|
|
118
138
|
def _run_overflow_check_command(args):
|
|
119
139
|
torch.npu.set_compile_mode(jit_compile=args.jit_compile)
|
|
120
140
|
npu_device = "npu:" + str(args.device_id)
|
|
121
|
-
|
|
122
|
-
|
|
141
|
+
api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
|
|
142
|
+
ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
|
|
143
|
+
api_info = api_info_file_checker.common_check()
|
|
123
144
|
try:
|
|
124
145
|
torch.npu.set_device(npu_device)
|
|
125
146
|
except Exception as error:
|
|
@@ -1,6 +1,23 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
1
18
|
import argparse
|
|
2
19
|
import os
|
|
3
|
-
import
|
|
20
|
+
import re
|
|
4
21
|
import sys
|
|
5
22
|
import time
|
|
6
23
|
import gc
|
|
@@ -17,43 +34,34 @@ else:
|
|
|
17
34
|
import torch
|
|
18
35
|
from tqdm import tqdm
|
|
19
36
|
|
|
20
|
-
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import
|
|
21
|
-
get_validated_result_csv_path, get_validated_details_csv_path, exec_api
|
|
37
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import BackwardMessage, UtDataInfo, \
|
|
38
|
+
get_validated_result_csv_path, get_validated_details_csv_path, exec_api, record_skip_info, is_unsupported_api
|
|
22
39
|
from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
|
|
23
40
|
from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \
|
|
24
41
|
initialize_save_path, UtDataProcessor, extract_basic_api_segments, ApiData
|
|
25
42
|
from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
|
|
26
43
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
|
|
27
|
-
from msprobe.pytorch.api_accuracy_checker.common.config import
|
|
44
|
+
from msprobe.pytorch.api_accuracy_checker.common.config import CheckerConfig
|
|
28
45
|
from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
|
|
29
|
-
from msprobe.core.common.file_utils import
|
|
30
|
-
|
|
46
|
+
from msprobe.core.common.file_utils import FileChecker, change_mode, \
|
|
47
|
+
create_directory, get_json_contents, read_csv, check_file_or_directory_path, check_crt_valid
|
|
31
48
|
from msprobe.pytorch.common.log import logger
|
|
32
49
|
from msprobe.pytorch.pt_config import parse_json_config
|
|
33
50
|
from msprobe.core.common.const import Const, FileCheckConst, CompareConst
|
|
51
|
+
from msprobe.core.common.utils import safe_get_value
|
|
34
52
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
|
|
35
53
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
|
|
54
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params
|
|
36
55
|
|
|
37
56
|
|
|
38
57
|
current_time = time.strftime("%Y%m%d%H%M%S")
|
|
39
58
|
UT_ERROR_DATA_DIR = 'ut_error_data' + current_time
|
|
40
59
|
RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
|
|
41
60
|
DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
|
|
42
|
-
RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
|
|
43
|
-
'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
|
|
44
|
-
'black_list', 'error_data_path', 'online_config'])
|
|
45
61
|
|
|
46
|
-
OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
|
|
47
62
|
|
|
48
63
|
not_backward_list = ['repeat_interleave']
|
|
49
|
-
not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
|
|
50
|
-
not_raise_dtype_set = {'type_as'}
|
|
51
64
|
|
|
52
|
-
RAISE_PRECISION = {
|
|
53
|
-
torch.float16: torch.float32,
|
|
54
|
-
torch.bfloat16: torch.float32,
|
|
55
|
-
torch.float32: torch.float64
|
|
56
|
-
}
|
|
57
65
|
|
|
58
66
|
tqdm_params = {
|
|
59
67
|
'smoothing': 0, # 平滑进度条的预计剩余时间,取值范围0到1
|
|
@@ -71,98 +79,6 @@ tqdm_params = {
|
|
|
71
79
|
}
|
|
72
80
|
|
|
73
81
|
|
|
74
|
-
def deal_detach(arg, to_detach=True):
|
|
75
|
-
return arg.detach() if to_detach else arg
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
|
|
79
|
-
'''
|
|
80
|
-
将标杆数据的dtype转换为raise_dtype
|
|
81
|
-
输入:
|
|
82
|
-
api_name:api名称
|
|
83
|
-
arg:标杆输入
|
|
84
|
-
raise_dtype:需要转换的dtype
|
|
85
|
-
输出:
|
|
86
|
-
arg: 转换dtype的标杆输入
|
|
87
|
-
'''
|
|
88
|
-
if api_name in hf_32_standard_api and arg.dtype == torch.float32:
|
|
89
|
-
return arg
|
|
90
|
-
if raise_dtype is None or arg.dtype not in RAISE_PRECISION or raise_dtype == arg.dtype:
|
|
91
|
-
return arg
|
|
92
|
-
return arg.type(raise_dtype)
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
def generate_device_params(input_args, input_kwargs, need_backward, api_name):
|
|
96
|
-
def recursive_arg_to_device(arg_in, to_detach):
|
|
97
|
-
if isinstance(arg_in, (list, tuple)):
|
|
98
|
-
return type(arg_in)(recursive_arg_to_device(arg, to_detach) for arg in arg_in)
|
|
99
|
-
elif isinstance(arg_in, torch.Tensor):
|
|
100
|
-
if need_backward and arg_in.requires_grad:
|
|
101
|
-
arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_()
|
|
102
|
-
temp_arg_in = arg_in * 1
|
|
103
|
-
arg_in = temp_arg_in.type_as(arg_in)
|
|
104
|
-
arg_in.retain_grad()
|
|
105
|
-
return arg_in
|
|
106
|
-
else:
|
|
107
|
-
return deal_detach(arg_in.clone(), to_detach).to(current_device)
|
|
108
|
-
else:
|
|
109
|
-
return arg_in
|
|
110
|
-
|
|
111
|
-
is_detach = api_name not in not_detach_set
|
|
112
|
-
device_args = recursive_arg_to_device(input_args, is_detach)
|
|
113
|
-
device_kwargs = \
|
|
114
|
-
{key: recursive_arg_to_device(value, key != "out" and is_detach) for key, value in input_kwargs.items()}
|
|
115
|
-
return device_args, device_kwargs
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
|
|
119
|
-
def recursive_arg_to_cpu(arg_in, to_detach, raise_dtype=None):
|
|
120
|
-
if isinstance(arg_in, (list, tuple)):
|
|
121
|
-
return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype) for arg in arg_in)
|
|
122
|
-
elif isinstance(arg_in, torch.Tensor):
|
|
123
|
-
if need_backward and arg_in.requires_grad:
|
|
124
|
-
arg_in = deal_detach(raise_bench_data_dtype(
|
|
125
|
-
api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_()
|
|
126
|
-
temp_arg_in = arg_in * 1
|
|
127
|
-
arg_in = temp_arg_in.type_as(arg_in)
|
|
128
|
-
arg_in.retain_grad()
|
|
129
|
-
return arg_in
|
|
130
|
-
else:
|
|
131
|
-
return deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach)
|
|
132
|
-
else:
|
|
133
|
-
return arg_in
|
|
134
|
-
|
|
135
|
-
def is_tensor_with_raise_precision(arg_in, check_kwargs=False):
|
|
136
|
-
if arg_in.dtype in RAISE_PRECISION:
|
|
137
|
-
return True
|
|
138
|
-
if check_kwargs and arg_in.dtype in [torch.half, torch.bfloat16]:
|
|
139
|
-
return True
|
|
140
|
-
return False
|
|
141
|
-
|
|
142
|
-
def recursive_find_dtypes(arg_in, kwargs=None, check_kwargs=False):
|
|
143
|
-
if isinstance(arg_in, (list, tuple)):
|
|
144
|
-
return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs) for arg in arg_in))
|
|
145
|
-
elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs):
|
|
146
|
-
return set([arg_in.dtype])
|
|
147
|
-
elif isinstance(arg_in, dict) and check_kwargs:
|
|
148
|
-
return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True) for v in arg_in.values()))
|
|
149
|
-
return set()
|
|
150
|
-
|
|
151
|
-
raise_dtype = None
|
|
152
|
-
need_raise_dtypes = recursive_find_dtypes(input_args)
|
|
153
|
-
need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
|
|
154
|
-
if len(need_raise_dtypes) == 1:
|
|
155
|
-
raise_dtype = RAISE_PRECISION.get(need_raise_dtypes.pop(), torch.float32)
|
|
156
|
-
elif len(need_raise_dtypes) >= 2:
|
|
157
|
-
raise_dtype = torch.float32
|
|
158
|
-
|
|
159
|
-
raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
|
|
160
|
-
is_detach = api_name not in not_detach_set
|
|
161
|
-
cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
|
|
162
|
-
cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for key, value in input_kwargs.items()}
|
|
163
|
-
return cpu_args, cpu_kwargs
|
|
164
|
-
|
|
165
|
-
|
|
166
82
|
def run_ut(config):
|
|
167
83
|
logger.info("start UT test")
|
|
168
84
|
if config.online_config.is_online:
|
|
@@ -179,10 +95,12 @@ def run_ut(config):
|
|
|
179
95
|
if config.online_config.is_online:
|
|
180
96
|
run_api_online(config, compare)
|
|
181
97
|
else:
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
98
|
+
csv_df = read_csv(config.result_csv_path)
|
|
99
|
+
try:
|
|
100
|
+
api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
|
|
101
|
+
except IndexError:
|
|
102
|
+
logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
|
|
103
|
+
api_name_set = set()
|
|
186
104
|
run_api_offline(config, compare, api_name_set)
|
|
187
105
|
for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
|
|
188
106
|
change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
@@ -198,17 +116,23 @@ def run_api_offline(config, compare, api_name_set):
|
|
|
198
116
|
if api_full_name in api_name_set:
|
|
199
117
|
continue
|
|
200
118
|
if is_unsupported_api(api_full_name):
|
|
119
|
+
skip_message = f"API {api_full_name} not support for run ut. SKIP."
|
|
120
|
+
compare_alg_results = err_column.to_column_value(CompareConst.SKIP, skip_message)
|
|
121
|
+
record_skip_info(api_full_name, compare, compare_alg_results)
|
|
201
122
|
continue
|
|
202
123
|
_, api_name = extract_basic_api_segments(api_full_name)
|
|
203
124
|
if not api_name:
|
|
204
125
|
err_message = f"API {api_full_name} not support for run ut. SKIP."
|
|
205
126
|
logger.error(err_message)
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
compare.record_results(result_info)
|
|
127
|
+
compare_alg_results = err_column.to_column_value(CompareConst.SKIP, err_message)
|
|
128
|
+
record_skip_info(api_full_name, compare, compare_alg_results)
|
|
209
129
|
continue
|
|
210
130
|
try:
|
|
211
131
|
if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
|
|
132
|
+
skip_message = f"API {api_name} in black list or not in white list. SKIP."
|
|
133
|
+
logger.info(skip_message)
|
|
134
|
+
compare_alg_results = err_column.to_column_value(CompareConst.SKIP, skip_message)
|
|
135
|
+
record_skip_info(api_full_name, compare, compare_alg_results)
|
|
212
136
|
continue
|
|
213
137
|
data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict)
|
|
214
138
|
is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info)
|
|
@@ -217,12 +141,11 @@ def run_api_offline(config, compare, api_name_set):
|
|
|
217
141
|
except Exception as err:
|
|
218
142
|
if "expected scalar type Long" in str(err):
|
|
219
143
|
logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
|
|
220
|
-
|
|
144
|
+
"'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
|
|
221
145
|
else:
|
|
222
146
|
logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
compare.record_results(result_info)
|
|
147
|
+
compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err))
|
|
148
|
+
record_skip_info(api_full_name, compare, compare_alg_results)
|
|
226
149
|
finally:
|
|
227
150
|
if is_gpu:
|
|
228
151
|
torch.cuda.empty_cache()
|
|
@@ -298,14 +221,6 @@ def blacklist_and_whitelist_filter(api_name, black_list, white_list):
|
|
|
298
221
|
return False
|
|
299
222
|
|
|
300
223
|
|
|
301
|
-
def is_unsupported_api(api_name):
|
|
302
|
-
split_name = api_name.split(Const.SEP)[0]
|
|
303
|
-
flag = split_name == Const.DISTRIBUTED
|
|
304
|
-
if flag:
|
|
305
|
-
logger.info(f"{split_name} api is not supported for run ut. SKIP.")
|
|
306
|
-
return flag
|
|
307
|
-
|
|
308
|
-
|
|
309
224
|
def do_save_error_data(api_full_name, data_info, error_data_path, is_fwd_success, is_bwd_success):
|
|
310
225
|
if not is_fwd_success or not is_bwd_success:
|
|
311
226
|
processor = UtDataProcessor(error_data_path)
|
|
@@ -327,12 +242,12 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
327
242
|
in_fwd_data_list.append(kwargs)
|
|
328
243
|
need_backward = api_full_name in backward_content
|
|
329
244
|
if not need_grad:
|
|
330
|
-
logger.warning("%s %s" % (api_full_name,
|
|
331
|
-
backward_message +=
|
|
245
|
+
logger.warning("%s %s" % (api_full_name, BackwardMessage.UNSUPPORT_BACKWARD_MESSAGE))
|
|
246
|
+
backward_message += BackwardMessage.UNSUPPORT_BACKWARD_MESSAGE
|
|
332
247
|
if api_name in not_backward_list:
|
|
333
248
|
need_grad = False
|
|
334
|
-
logger.
|
|
335
|
-
backward_message +=
|
|
249
|
+
logger.info("%s %s" % (api_full_name, BackwardMessage.NO_BACKWARD_RESULT_MESSAGE))
|
|
250
|
+
backward_message += BackwardMessage.NO_BACKWARD_RESULT_MESSAGE
|
|
336
251
|
need_backward = need_backward and need_grad
|
|
337
252
|
if kwargs.get("device"):
|
|
338
253
|
del kwargs["device"]
|
|
@@ -353,16 +268,20 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
353
268
|
if need_backward:
|
|
354
269
|
if need_to_backward(grad_index, out):
|
|
355
270
|
backward_args = backward_content[api_full_name].get("input")
|
|
356
|
-
|
|
271
|
+
func_options = {
|
|
272
|
+
'real_data_path': real_data_path
|
|
273
|
+
}
|
|
274
|
+
grad = gen_args(backward_args, api_name, func_options)
|
|
275
|
+
grad = safe_get_value(grad, 0, "grad")
|
|
357
276
|
bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
|
|
358
277
|
bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
|
|
359
278
|
device_grad = grad.clone().detach().to(current_device)
|
|
360
279
|
device_grad_out = run_backward(device_args, device_grad, grad_index, device_out)
|
|
361
280
|
else:
|
|
362
|
-
backward_message +=
|
|
281
|
+
backward_message += BackwardMessage.MULTIPLE_BACKWARD_MESSAGE
|
|
363
282
|
if api_name == "npu_fusion_attention":
|
|
364
|
-
out = out
|
|
365
|
-
device_out = device_out
|
|
283
|
+
out = safe_get_value(out, 0, "out")
|
|
284
|
+
device_out = safe_get_value(device_out, 0, "device_out")
|
|
366
285
|
|
|
367
286
|
return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
|
|
368
287
|
|
|
@@ -398,6 +317,9 @@ def need_to_backward(grad_index, out):
|
|
|
398
317
|
|
|
399
318
|
def run_backward(args, grad, grad_index, out):
|
|
400
319
|
if grad_index is not None:
|
|
320
|
+
if grad_index >= len(out):
|
|
321
|
+
logger.error(f"Run backward error when grad_index is {grad_index}")
|
|
322
|
+
raise IndexError(f"Run backward error when grad_index is {grad_index}")
|
|
401
323
|
out[grad_index].backward(grad)
|
|
402
324
|
else:
|
|
403
325
|
out.backward(grad)
|
|
@@ -411,12 +333,11 @@ def run_backward(args, grad, grad_index, out):
|
|
|
411
333
|
|
|
412
334
|
|
|
413
335
|
def initialize_save_error_data(error_data_path):
|
|
414
|
-
check_path_before_create(error_data_path)
|
|
415
336
|
create_directory(error_data_path)
|
|
416
337
|
error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
|
|
417
338
|
ability=FileCheckConst.WRITE_ABLE)
|
|
418
339
|
error_data_path = error_data_path_checker.common_check()
|
|
419
|
-
error_data_path =initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
|
|
340
|
+
error_data_path = initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
|
|
420
341
|
return error_data_path
|
|
421
342
|
|
|
422
343
|
|
|
@@ -477,7 +398,8 @@ def preprocess_forward_content(forward_content):
|
|
|
477
398
|
if key not in arg_cache:
|
|
478
399
|
filtered_new_args = [
|
|
479
400
|
{k: v for k, v in arg.items() if k not in ['Max', 'Min']}
|
|
480
|
-
for arg in value['input_args']
|
|
401
|
+
for arg in value['input_args']
|
|
402
|
+
if isinstance(arg, dict)
|
|
481
403
|
]
|
|
482
404
|
arg_cache[key] = (filtered_new_args, value['input_kwargs'])
|
|
483
405
|
|
|
@@ -512,7 +434,49 @@ def _run_ut(parser=None):
|
|
|
512
434
|
run_ut_command(args)
|
|
513
435
|
|
|
514
436
|
|
|
437
|
+
def checked_online_config(online_config):
|
|
438
|
+
if not online_config.is_online:
|
|
439
|
+
return
|
|
440
|
+
if not isinstance(online_config.is_online, bool):
|
|
441
|
+
raise ValueError("is_online must be bool type")
|
|
442
|
+
# rank_list
|
|
443
|
+
if not isinstance(online_config.rank_list, list):
|
|
444
|
+
raise ValueError("rank_list must be a list")
|
|
445
|
+
if online_config.rank_list and not all(isinstance(rank, int) for rank in online_config.rank_list):
|
|
446
|
+
raise ValueError("All elements in rank_list must be integers")
|
|
447
|
+
|
|
448
|
+
# nfs_path
|
|
449
|
+
if online_config.nfs_path:
|
|
450
|
+
check_file_or_directory_path(online_config.nfs_path, isdir=True)
|
|
451
|
+
return
|
|
452
|
+
# tls_path
|
|
453
|
+
if online_config.tls_path:
|
|
454
|
+
check_file_or_directory_path(online_config.tls_path, isdir=True)
|
|
455
|
+
check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
|
|
456
|
+
check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
|
|
457
|
+
check_crt_valid(os.path.join(online_config.tls_path, "server.crt"))
|
|
458
|
+
|
|
459
|
+
# host and port
|
|
460
|
+
if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
|
|
461
|
+
raise Exception(f"host: {online_config.host} is invalid.")
|
|
462
|
+
if not isinstance(online_config.port, int) or not (0 < online_config.port <= 65535):
|
|
463
|
+
raise Exception(f"port: {online_config.port} is invalid, port range 0-65535.")
|
|
464
|
+
|
|
465
|
+
|
|
515
466
|
def run_ut_command(args):
|
|
467
|
+
if args.config_path:
|
|
468
|
+
config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
|
|
469
|
+
FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
|
|
470
|
+
checked_config_path = config_path_checker.common_check()
|
|
471
|
+
_, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
|
|
472
|
+
checker_config = CheckerConfig(task_config)
|
|
473
|
+
else:
|
|
474
|
+
checker_config = CheckerConfig()
|
|
475
|
+
|
|
476
|
+
if not checker_config.is_online and not args.api_info_file:
|
|
477
|
+
logger.error("Please provide api_info_file for offline run ut.")
|
|
478
|
+
raise Exception("Please provide api_info_file for offline run ut.")
|
|
479
|
+
|
|
516
480
|
if not is_gpu:
|
|
517
481
|
torch.npu.set_compile_mode(jit_compile=args.jit_compile)
|
|
518
482
|
used_device = current_device + ":" + str(args.device_id[0])
|
|
@@ -529,17 +493,16 @@ def run_ut_command(args):
|
|
|
529
493
|
# 离线场景下,forward_content, backward_content, real_data_path从api_info_file中解析
|
|
530
494
|
forward_content, backward_content, real_data_path = None, None, None
|
|
531
495
|
if args.api_info_file:
|
|
532
|
-
api_info_file_checker = FileChecker(file_path
|
|
533
|
-
ability
|
|
496
|
+
api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
|
|
497
|
+
ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
|
|
534
498
|
checked_api_info = api_info_file_checker.common_check()
|
|
535
499
|
forward_content, backward_content, real_data_path = parse_json_info_forward_backward(checked_api_info)
|
|
536
500
|
if args.filter_api:
|
|
537
|
-
logger.info("Start filtering the api in the
|
|
501
|
+
logger.info("Start filtering the api in the api_info_file.")
|
|
538
502
|
forward_content = preprocess_forward_content(forward_content)
|
|
539
|
-
logger.info("Finish filtering the api in the
|
|
503
|
+
logger.info("Finish filtering the api in the api_info_file.")
|
|
540
504
|
|
|
541
|
-
out_path =
|
|
542
|
-
check_path_before_create(out_path)
|
|
505
|
+
out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
|
|
543
506
|
create_directory(out_path)
|
|
544
507
|
out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
|
|
545
508
|
out_path = out_path_checker.common_check()
|
|
@@ -550,40 +513,27 @@ def run_ut_command(args):
|
|
|
550
513
|
if args.result_csv_path:
|
|
551
514
|
result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
|
|
552
515
|
details_csv_path = get_validated_details_csv_path(result_csv_path)
|
|
553
|
-
white_list = msCheckerConfig.white_list
|
|
554
|
-
black_list = msCheckerConfig.black_list
|
|
555
|
-
error_data_path = msCheckerConfig.error_data_path
|
|
556
|
-
is_online = msCheckerConfig.is_online
|
|
557
|
-
nfs_path = msCheckerConfig.nfs_path
|
|
558
|
-
host = msCheckerConfig.host
|
|
559
|
-
port = msCheckerConfig.port
|
|
560
|
-
rank_list = msCheckerConfig.rank_list
|
|
561
|
-
tls_path = msCheckerConfig.tls_path
|
|
562
|
-
if args.config_path:
|
|
563
|
-
config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
|
|
564
|
-
FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
|
|
565
|
-
checked_config_path = config_path_checker.common_check()
|
|
566
|
-
_, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
|
|
567
|
-
white_list = task_config.white_list
|
|
568
|
-
black_list = task_config.black_list
|
|
569
|
-
error_data_path = task_config.error_data_path
|
|
570
|
-
is_online = task_config.is_online
|
|
571
|
-
nfs_path = task_config.nfs_path
|
|
572
|
-
host = task_config.host
|
|
573
|
-
port = task_config.port
|
|
574
|
-
rank_list = task_config.rank_list
|
|
575
|
-
tls_path = task_config.tls_path
|
|
576
516
|
|
|
517
|
+
error_data_path = checker_config.error_data_path
|
|
577
518
|
if save_error_data:
|
|
578
519
|
if args.result_csv_path:
|
|
579
520
|
time_info = result_csv_path.split('.')[0].split('_')[-1]
|
|
580
521
|
global UT_ERROR_DATA_DIR
|
|
581
522
|
UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
|
|
582
523
|
error_data_path = initialize_save_error_data(error_data_path)
|
|
583
|
-
online_config =
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
524
|
+
online_config = checker_config.get_online_config()
|
|
525
|
+
checked_online_config(online_config)
|
|
526
|
+
config_params = {
|
|
527
|
+
'forward_content': forward_content,
|
|
528
|
+
'backward_content': backward_content,
|
|
529
|
+
'result_csv_path': result_csv_path,
|
|
530
|
+
'details_csv_path': details_csv_path,
|
|
531
|
+
'save_error_data': save_error_data,
|
|
532
|
+
'is_continue_run_ut': args.result_csv_path,
|
|
533
|
+
'real_data_path': real_data_path,
|
|
534
|
+
'error_data_path': error_data_path
|
|
535
|
+
}
|
|
536
|
+
run_ut_config = checker_config.get_run_ut_config(**config_params)
|
|
587
537
|
run_ut(run_ut_config)
|
|
588
538
|
|
|
589
539
|
|