mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
msprobe/msprobe.py
CHANGED
|
@@ -15,13 +15,16 @@
|
|
|
15
15
|
|
|
16
16
|
import argparse
|
|
17
17
|
import sys
|
|
18
|
-
|
|
19
|
-
from msprobe.
|
|
20
|
-
from msprobe.
|
|
21
|
-
from msprobe.
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
18
|
+
import importlib.util
|
|
19
|
+
from msprobe.core.compare.utils import _compare_parser
|
|
20
|
+
from msprobe.core.common.log import logger
|
|
21
|
+
from msprobe.core.compare.compare_cli import compare_cli
|
|
22
|
+
from msprobe.core.common.const import Const
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def is_module_available(module_name):
|
|
26
|
+
spec = importlib.util.find_spec(module_name)
|
|
27
|
+
return spec is not None
|
|
25
28
|
|
|
26
29
|
|
|
27
30
|
def main():
|
|
@@ -31,37 +34,74 @@ def main():
|
|
|
31
34
|
"Providing one-site accuracy difference debugging toolkit for training on Ascend Devices.\n"
|
|
32
35
|
f"For any issue, refer README.md first",
|
|
33
36
|
)
|
|
37
|
+
|
|
34
38
|
parser.set_defaults(print_help=parser.print_help)
|
|
35
|
-
parser.add_argument('-f', '--framework', required=True, choices=[
|
|
39
|
+
parser.add_argument('-f', '--framework', required=True, choices=[Const.PT_FRAMEWORK, Const.MS_FRAMEWORK],
|
|
36
40
|
help='Deep learning framework.')
|
|
37
41
|
subparsers = parser.add_subparsers()
|
|
38
42
|
subparsers.add_parser('parse')
|
|
43
|
+
compare_cmd_parser = subparsers.add_parser('compare')
|
|
39
44
|
run_ut_cmd_parser = subparsers.add_parser('run_ut')
|
|
40
45
|
multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut')
|
|
41
46
|
api_precision_compare_cmd_parser = subparsers.add_parser('api_precision_compare')
|
|
42
47
|
run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check')
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
48
|
+
_compare_parser(compare_cmd_parser)
|
|
49
|
+
is_torch_available=is_module_available("torch")
|
|
50
|
+
is_mindspore_available = is_module_available("mindspore")
|
|
51
|
+
if is_torch_available:
|
|
52
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
|
|
53
|
+
from msprobe.pytorch.parse_tool.cli import parse as cli_parse
|
|
54
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut
|
|
55
|
+
from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, \
|
|
56
|
+
_api_precision_compare_command
|
|
57
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
|
|
58
|
+
_run_overflow_check_command
|
|
59
|
+
|
|
60
|
+
_run_ut_parser(run_ut_cmd_parser)
|
|
61
|
+
_run_ut_parser(multi_run_ut_cmd_parser)
|
|
62
|
+
multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
|
|
63
|
+
help='Number of splits for parallel processing. Range: 1-64')
|
|
64
|
+
_api_precision_compare_parser(api_precision_compare_cmd_parser)
|
|
65
|
+
_run_overflow_check_parser(run_overflow_check_cmd_parser)
|
|
66
|
+
elif is_mindspore_available:
|
|
67
|
+
from msprobe.mindspore.api_accuracy_checker.main import add_api_accuracy_checker_argument
|
|
68
|
+
add_api_accuracy_checker_argument(run_ut_cmd_parser)
|
|
69
|
+
|
|
49
70
|
if len(sys.argv) == 1:
|
|
50
71
|
parser.print_help()
|
|
51
72
|
sys.exit(0)
|
|
52
73
|
args = parser.parse_args(sys.argv[1:])
|
|
53
|
-
if sys.argv[
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
74
|
+
if sys.argv[2] == Const.PT_FRAMEWORK:
|
|
75
|
+
if not is_torch_available:
|
|
76
|
+
logger.error("PyTorch does not exist, please install PyTorch library")
|
|
77
|
+
raise Exception("PyTorch does not exist, please install PyTorch library")
|
|
78
|
+
if sys.argv[3] == "run_ut":
|
|
79
|
+
run_ut_command(args)
|
|
80
|
+
elif sys.argv[3] == "parse":
|
|
81
|
+
cli_parse()
|
|
82
|
+
elif sys.argv[3] == "multi_run_ut":
|
|
83
|
+
config = prepare_config(args)
|
|
84
|
+
run_parallel_ut(config)
|
|
85
|
+
elif sys.argv[3] == "api_precision_compare":
|
|
86
|
+
_api_precision_compare_command(args)
|
|
87
|
+
elif sys.argv[3] == "run_overflow_check":
|
|
88
|
+
_run_overflow_check_command(args)
|
|
89
|
+
elif sys.argv[3] == "compare":
|
|
90
|
+
if args.cell_mapping is not None or args.api_mapping is not None:
|
|
91
|
+
logger.error("Argument -cm or -am is not supported in PyTorch framework")
|
|
92
|
+
raise Exception("Argument -cm or -am is not supported in PyTorch framework")
|
|
93
|
+
compare_cli(args)
|
|
94
|
+
else:
|
|
95
|
+
if not is_module_available(Const.MS_FRAMEWORK):
|
|
96
|
+
logger.error("MindSpore does not exist, please install MindSpore library")
|
|
97
|
+
raise Exception("MindSpore does not exist, please install MindSpore library")
|
|
98
|
+
if sys.argv[3] == "compare":
|
|
99
|
+
if isinstance(args.api_mapping, str):
|
|
100
|
+
logger.warning("User defined mapping tables are not supported in the current version")
|
|
101
|
+
compare_cli(args)
|
|
102
|
+
elif sys.argv[3] == "run_ut":
|
|
103
|
+
from msprobe.mindspore.api_accuracy_checker.main import api_checker_main
|
|
104
|
+
api_checker_main(args)
|
|
65
105
|
|
|
66
106
|
if __name__ == "__main__":
|
|
67
107
|
main()
|
msprobe/pytorch/__init__.py
CHANGED
|
@@ -1,17 +1,14 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import yaml
|
|
3
|
-
from msprobe.
|
|
4
|
-
from msprobe.
|
|
5
|
-
from msprobe.
|
|
6
|
-
|
|
7
|
-
WrapApi = set(WrapFunctionalOps) | set(WrapTensorOps) | set(WrapTorchOps)
|
|
3
|
+
from msprobe.core.common.utils import check_file_or_directory_path
|
|
4
|
+
from msprobe.core.common.utils import load_yaml
|
|
5
|
+
from msprobe.pytorch.pt_config import RunUTConfig
|
|
8
6
|
|
|
9
7
|
|
|
10
8
|
class Config:
|
|
11
9
|
def __init__(self, yaml_file):
|
|
12
10
|
check_file_or_directory_path(yaml_file, False)
|
|
13
|
-
|
|
14
|
-
config = yaml.safe_load(file)
|
|
11
|
+
config = load_yaml(yaml_file)
|
|
15
12
|
self.config = {key: self.validate(key, value) for key, value in config.items()}
|
|
16
13
|
|
|
17
14
|
def __getattr__(self, item):
|
|
@@ -24,8 +21,15 @@ class Config:
|
|
|
24
21
|
def validate(key, value):
|
|
25
22
|
validators = {
|
|
26
23
|
'white_list': list,
|
|
24
|
+
'black_list': list,
|
|
27
25
|
'error_data_path': str,
|
|
28
|
-
'precision': int
|
|
26
|
+
'precision': int,
|
|
27
|
+
'is_online': bool,
|
|
28
|
+
'nfs_path': str,
|
|
29
|
+
'host': str,
|
|
30
|
+
'port': int,
|
|
31
|
+
'rank_list': list,
|
|
32
|
+
'tls_path': str
|
|
29
33
|
}
|
|
30
34
|
if key not in validators:
|
|
31
35
|
raise ValueError(f"{key} must be one of {validators.keys()}")
|
|
@@ -34,14 +38,15 @@ class Config:
|
|
|
34
38
|
if key == 'precision' and value < 0:
|
|
35
39
|
raise ValueError("precision must be greater than 0")
|
|
36
40
|
if key == 'white_list':
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
41
|
+
RunUTConfig.check_filter_list_config(key, value)
|
|
42
|
+
if key == 'black_list':
|
|
43
|
+
RunUTConfig.check_filter_list_config(key, value)
|
|
44
|
+
if key == 'error_data_path':
|
|
45
|
+
RunUTConfig.check_error_data_path_config(value)
|
|
46
|
+
if key == 'nfs_path':
|
|
47
|
+
RunUTConfig.check_nfs_path_config(value)
|
|
48
|
+
if key == 'tls_path':
|
|
49
|
+
RunUTConfig.check_tls_path_config(value)
|
|
45
50
|
return value
|
|
46
51
|
|
|
47
52
|
|
|
@@ -14,10 +14,8 @@
|
|
|
14
14
|
# See the License for the specific language governing permissions and
|
|
15
15
|
# limitations under the License.
|
|
16
16
|
"""
|
|
17
|
-
import json
|
|
18
17
|
import os
|
|
19
18
|
import re
|
|
20
|
-
import csv
|
|
21
19
|
|
|
22
20
|
import torch
|
|
23
21
|
|
|
@@ -38,12 +36,6 @@ class DumpException(CompareException):
|
|
|
38
36
|
pass
|
|
39
37
|
|
|
40
38
|
|
|
41
|
-
def write_csv(data, filepath):
|
|
42
|
-
with FileOpen(filepath, 'a', encoding='utf-8-sig') as f:
|
|
43
|
-
writer = csv.writer(f)
|
|
44
|
-
writer.writerows(data)
|
|
45
|
-
|
|
46
|
-
|
|
47
39
|
def check_object_type(check_object, allow_type):
|
|
48
40
|
"""
|
|
49
41
|
Function Description:
|
|
@@ -59,58 +51,6 @@ def check_object_type(check_object, allow_type):
|
|
|
59
51
|
raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
60
52
|
|
|
61
53
|
|
|
62
|
-
def check_file_or_directory_path(path, isdir=False):
|
|
63
|
-
"""
|
|
64
|
-
Function Description:
|
|
65
|
-
check whether the path is valid
|
|
66
|
-
Parameter:
|
|
67
|
-
path: the path to check
|
|
68
|
-
isdir: the path is dir or file
|
|
69
|
-
Exception Description:
|
|
70
|
-
when invalid data throw exception
|
|
71
|
-
"""
|
|
72
|
-
if isdir:
|
|
73
|
-
if not os.path.exists(path):
|
|
74
|
-
logger.error('The path {} is not exist.'.format(path))
|
|
75
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
76
|
-
|
|
77
|
-
if not os.path.isdir(path):
|
|
78
|
-
logger.error('The path {} is not a directory.'.format(path))
|
|
79
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
80
|
-
|
|
81
|
-
if not os.access(path, os.W_OK):
|
|
82
|
-
logger.error(
|
|
83
|
-
'The path {} does not have permission to write. Please check the path permission'.format(path))
|
|
84
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
85
|
-
else:
|
|
86
|
-
if not os.path.isfile(path):
|
|
87
|
-
logger.error('{} is an invalid file or non-exist.'.format(path))
|
|
88
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
89
|
-
|
|
90
|
-
if not os.access(path, os.R_OK):
|
|
91
|
-
logger.error(
|
|
92
|
-
'The path {} does not have permission to read. Please check the path permission'.format(path))
|
|
93
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
def get_json_contents(file_path):
|
|
97
|
-
ops = get_file_content_bytes(file_path)
|
|
98
|
-
try:
|
|
99
|
-
json_obj = json.loads(ops)
|
|
100
|
-
except ValueError as error:
|
|
101
|
-
logger.error('Failed to load "%s". %s' % (file_path, str(error)))
|
|
102
|
-
raise CompareException(CompareException.INVALID_FILE_ERROR) from error
|
|
103
|
-
if not isinstance(json_obj, dict):
|
|
104
|
-
logger.error('Json file %s, content is not a dictionary!' % file_path)
|
|
105
|
-
raise CompareException(CompareException.INVALID_FILE_ERROR)
|
|
106
|
-
return json_obj
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
def get_file_content_bytes(file):
|
|
110
|
-
with FileOpen(file, 'rb') as file_handle:
|
|
111
|
-
return file_handle.read()
|
|
112
|
-
|
|
113
|
-
|
|
114
54
|
class SoftlinkCheckException(Exception):
|
|
115
55
|
pass
|
|
116
56
|
|
|
@@ -166,6 +106,7 @@ def initialize_save_path(save_path, dir_name):
|
|
|
166
106
|
os.mkdir(data_path, mode=FileCheckConst.DATA_DIR_AUTHORITY)
|
|
167
107
|
data_path_checker = FileChecker(data_path, FileCheckConst.DIR)
|
|
168
108
|
data_path_checker.common_check()
|
|
109
|
+
return data_path
|
|
169
110
|
|
|
170
111
|
|
|
171
112
|
def write_pt(file_path, tensor):
|
|
@@ -6,9 +6,6 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAM
|
|
|
6
6
|
from msprobe.core.common.const import CompareConst
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
DEFAULT_THRESHOLD = 1
|
|
10
|
-
|
|
11
|
-
|
|
12
9
|
#cos
|
|
13
10
|
def cosine_sim(bench_output, device_output):
|
|
14
11
|
msg = ""
|
|
@@ -197,8 +194,8 @@ def check_norm_value(normal_value_mask, rel_err, rtol):
|
|
|
197
194
|
|
|
198
195
|
def get_ulp_err(bench_output, device_output, dtype):
|
|
199
196
|
parameters = ULP_PARAMETERS.get(dtype)
|
|
200
|
-
min_eb = parameters.get('min_eb'
|
|
201
|
-
exponent_num = parameters.get('exponent_num'
|
|
197
|
+
min_eb = parameters.get('min_eb')[0]
|
|
198
|
+
exponent_num = parameters.get('exponent_num')[0]
|
|
202
199
|
abs_bench = np.abs(bench_output)
|
|
203
200
|
eb = np.where(abs_bench == 0, 0, np.floor(np.log2(abs_bench)))
|
|
204
201
|
eb = np.maximum(eb, min_eb)
|
|
@@ -7,19 +7,19 @@ from collections import namedtuple
|
|
|
7
7
|
import torch
|
|
8
8
|
import pandas as pd
|
|
9
9
|
|
|
10
|
-
from msprobe.
|
|
10
|
+
from msprobe.core.common.utils import write_csv
|
|
11
11
|
from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
|
|
12
12
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
|
|
13
13
|
API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
|
|
14
|
-
ApiPrecisionCompareColumn,
|
|
14
|
+
ApiPrecisionCompareColumn, absolute_standard_api, binary_standard_api, ulp_standard_api, thousandth_standard_api, \
|
|
15
15
|
BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage, is_inf_or_nan, \
|
|
16
16
|
check_inf_or_nan
|
|
17
17
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
|
|
18
|
-
from msprobe.pytorch.api_accuracy_checker.run_ut.
|
|
18
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path
|
|
19
19
|
from msprobe.core.common.file_check import FileChecker, change_mode, check_path_before_create, create_directory
|
|
20
20
|
from msprobe.pytorch.common.log import logger
|
|
21
21
|
from msprobe.core.common.utils import CompareException
|
|
22
|
-
from msprobe.core.common.const import CompareConst, FileCheckConst
|
|
22
|
+
from msprobe.core.common.const import CompareConst, FileCheckConst, Const
|
|
23
23
|
|
|
24
24
|
CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
|
|
25
25
|
BenchmarkInf_Nan_Consistency = namedtuple('BenchmarkInf_Nan_Consistency', ['small_value_inf_nan_consistency',
|
|
@@ -289,15 +289,38 @@ def api_precision_compare(config):
|
|
|
289
289
|
change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
290
290
|
|
|
291
291
|
|
|
292
|
+
def online_api_precision_compare(online_config):
|
|
293
|
+
rank = online_config.rank
|
|
294
|
+
result_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.result_csv_path).replace("_rank*.csv", f"_rank{rank}.csv")
|
|
295
|
+
details_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.details_csv_path).replace("_rank*.csv", f"_rank{rank}.csv")
|
|
296
|
+
detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
|
|
297
|
+
result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
|
|
298
|
+
if not os.path.exists(result_csv_path):
|
|
299
|
+
write_csv(result_csv_title, result_csv_path)
|
|
300
|
+
if not os.path.exists(details_csv_path):
|
|
301
|
+
write_csv(detail_csv_title, details_csv_path)
|
|
302
|
+
config = CompareConfig("", "", result_csv_path, details_csv_path)
|
|
303
|
+
try:
|
|
304
|
+
npu_data, gpu_data = online_config.npu_data, online_config.gpu_data
|
|
305
|
+
check_csv_columns(npu_data.columns, "npu_csv")
|
|
306
|
+
check_csv_columns(gpu_data.columns, "gpu_csv")
|
|
307
|
+
analyse_csv(npu_data, gpu_data, config)
|
|
308
|
+
except Exception as err:
|
|
309
|
+
logger.error(f"Online api precision compare Error: {str(err)}")
|
|
310
|
+
change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
311
|
+
change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
312
|
+
|
|
313
|
+
|
|
292
314
|
def analyse_csv(npu_data, gpu_data, config):
|
|
293
315
|
forward_status, backward_status = [], []
|
|
294
|
-
last_api_name, last_api_dtype = None, None
|
|
316
|
+
last_api_name, last_api_dtype, last_api_full_name = None, None, None
|
|
295
317
|
for _, row_npu in npu_data.iterrows():
|
|
296
318
|
message = ''
|
|
297
319
|
compare_column = ApiPrecisionOutputColumn()
|
|
298
320
|
full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
|
|
299
321
|
row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status]
|
|
300
|
-
|
|
322
|
+
api_type, api_name, api_nums, direction_status, _, _ = full_api_name_with_direction_status.split(Const.SEP)
|
|
323
|
+
api_full_name = Const.SEP.join([api_type, api_name, api_nums])
|
|
301
324
|
if row_gpu.empty:
|
|
302
325
|
logger.warning(f'This API : {full_api_name_with_direction_status} does not exist in the GPU data.')
|
|
303
326
|
continue
|
|
@@ -315,14 +338,14 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
315
338
|
write_detail_csv(compare_column.to_column_value(), config.details_csv_path)
|
|
316
339
|
else:
|
|
317
340
|
compare_column.api_name = full_api_name_with_direction_status
|
|
318
|
-
if api_name in
|
|
341
|
+
if api_name in thousandth_standard_api:
|
|
319
342
|
new_status = record_thousandth_threshold_result(compare_column, row_npu)
|
|
320
343
|
elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in BINARY_COMPARE_UNSUPPORT_LIST or \
|
|
321
|
-
api_name in
|
|
344
|
+
api_name in binary_standard_api:
|
|
322
345
|
new_status = record_binary_consistency_result(api_name, compare_column, row_npu)
|
|
323
|
-
elif api_name in
|
|
346
|
+
elif api_name in absolute_standard_api:
|
|
324
347
|
new_status = record_absolute_threshold_result(compare_column, row_npu)
|
|
325
|
-
elif api_name in
|
|
348
|
+
elif api_name in ulp_standard_api and \
|
|
326
349
|
row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in ULP_COMPARE_SUPPORT_LIST:
|
|
327
350
|
us = ULPStandard(full_api_name_with_direction_status, row_npu, row_gpu)
|
|
328
351
|
new_status = record_ulp_compare_result(compare_column, us)
|
|
@@ -335,6 +358,7 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
335
358
|
if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
|
|
336
359
|
message = unsupported_message
|
|
337
360
|
write_csv([[last_api_name, "skip", "skip", message]], config.result_csv_path)
|
|
361
|
+
print_test_success(api_full_name, "skip", "skip")
|
|
338
362
|
forward_status, backward_status = [], []
|
|
339
363
|
message = ''
|
|
340
364
|
else:
|
|
@@ -342,11 +366,13 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
342
366
|
backward_result = get_api_checker_result(backward_status)
|
|
343
367
|
message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
|
|
344
368
|
write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
|
|
369
|
+
print_test_success(api_full_name, forward_result, backward_result)
|
|
345
370
|
forward_status, backward_status = [], []
|
|
346
371
|
message = ''
|
|
347
372
|
|
|
348
373
|
is_supported = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in API_PRECISION_COMPARE_UNSUPPORT_LIST
|
|
349
374
|
last_api_name = api_name
|
|
375
|
+
last_api_full_name = api_full_name
|
|
350
376
|
|
|
351
377
|
last_api_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
|
|
352
378
|
if not is_supported:
|
|
@@ -363,11 +389,21 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
363
389
|
if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
|
|
364
390
|
message = unsupported_message
|
|
365
391
|
write_csv([[last_api_name, "skip", "skip", message]], config.result_csv_path)
|
|
392
|
+
print_test_success(last_api_full_name, "skip", "skip")
|
|
366
393
|
else:
|
|
367
394
|
forward_result = get_api_checker_result(forward_status)
|
|
368
395
|
backward_result = get_api_checker_result(backward_status)
|
|
369
396
|
message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
|
|
370
397
|
write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
|
|
398
|
+
print_test_success(last_api_full_name, forward_result, backward_result)
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def print_test_success(api_full_name, forward_result, backward_result):
    """Log whether the forward and backward comparison of one API succeeded.

    Forward counts as successful only when ``forward_result`` equals
    ``CompareConst.PASS``. Backward also counts as successful when
    ``backward_result`` equals ``CompareConst.SPACE``.
    """
    fwd_ok = forward_result == CompareConst.PASS
    bwd_ok = backward_result in (CompareConst.PASS, CompareConst.SPACE)
    summary = (f"running api_full_name {api_full_name} compare, "
               f"is_fwd_success: {fwd_ok}, "
               f"is_bwd_success: {bwd_ok}")
    logger.info(summary)
|
|
371
407
|
|
|
372
408
|
|
|
373
409
|
def check_error_rate(npu_error_rate):
|
|
@@ -1,27 +1,28 @@
|
|
|
1
1
|
# Run the comparison and present the results
|
|
2
2
|
import os
|
|
3
3
|
from collections import namedtuple
|
|
4
|
-
|
|
4
|
+
|
|
5
5
|
import numpy as np
|
|
6
|
-
from msprobe.
|
|
7
|
-
|
|
8
|
-
from msprobe.
|
|
9
|
-
DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, \
|
|
10
|
-
ULPStandardApi, ThousandthStandardApi, apis_threshold
|
|
11
|
-
from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
|
|
6
|
+
from msprobe.core.common.utils import write_csv, get_json_contents, CompareException
|
|
7
|
+
import torch
|
|
8
|
+
from msprobe.core.common.const import Const, CompareConst
|
|
12
9
|
from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \
|
|
13
10
|
get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \
|
|
14
11
|
get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \
|
|
15
12
|
check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err
|
|
16
13
|
from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
|
|
17
|
-
from msprobe.
|
|
14
|
+
from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
|
|
15
|
+
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
|
|
16
|
+
DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, absolute_standard_api, binary_standard_api, \
|
|
17
|
+
ulp_standard_api, thousandth_standard_api, apis_threshold
|
|
18
|
+
from msprobe.pytorch.common.log import logger
|
|
18
19
|
|
|
19
20
|
|
|
20
21
|
# Positional record consumed by the Comparator CSV writers; the writers index
# it by position, so the field order matters.
ResultInfo = namedtuple(
    'ResultInfo',
    [
        'full_api_name',
        'fwd_success_status',
        'bwd_success_status',
        'fwd_compare_alg_results',
        'bwd_compare_alg_results',
        'rank',
    ],
)


# Number of leading fields copied into a summary row; also the position of the
# forward result group inside a result record.
INDEX_TEST_RESULT_GROUP = 3
# First entry of a result group.
INDEX_FIRST_GROUP = 0
# Last element of a result entry, read as the message for "SKIP" results.
INDEX_MESSAGE = -1
|
|
27
28
|
|
|
@@ -33,20 +34,34 @@ class Comparator:
|
|
|
33
34
|
COLUMN_BACKWARD_SUCCESS = "Backward Test Success"
|
|
34
35
|
COLUMN_STACK_INFO = "Traceback callstack info"
|
|
35
36
|
|
|
36
|
-
def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None):
|
|
37
|
-
self.
|
|
38
|
-
self.
|
|
39
|
-
|
|
37
|
+
def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None, config=None):
    """Record where results are written and optionally load stack info.

    Args:
        result_csv_path: Path of the summary result CSV.
        details_csv_path: Path of the detail CSV.
        is_continue_run_ut: When falsy, ``write_csv_title`` is called so
            that missing CSV files receive their title rows.
        stack_info_json_path: Optional JSON file read with
            ``get_json_contents`` into ``self.stack_info``; when falsy,
            ``self.stack_info`` is ``None``.
        config: Optional object exposing ``online_config.is_online`` and
            ``online_config.rank_list``. When online, both paths become
            ``str.format`` templates carrying a ``_rank{}`` suffix and one
            concrete path is generated per listed rank.
    """

    def _rank_template(csv_path):
        # Double any literal braces so the later ``.format(rank)`` fills
        # only the rank placeholder inserted below.
        escaped = csv_path.replace("{", "{{").replace("}", "}}")
        # Rewrite only the last ".csv": a plain ``str.replace`` would also
        # rewrite ".csv" inside directory names and add extra placeholders.
        return "_rank{}.csv".join(escaped.rsplit(".csv", 1))

    self.save_path_str = result_csv_path
    self.detail_save_path_str = details_csv_path
    self.save_path_list = [result_csv_path]
    self.detail_save_path_list = [details_csv_path]

    if config and config.online_config.is_online:
        rank_list = config.online_config.rank_list
        self.save_path_str = _rank_template(result_csv_path)
        self.detail_save_path_str = _rank_template(details_csv_path)
        self.save_path_list = [self.save_path_str.format(rank) for rank in rank_list]
        self.detail_save_path_list = [self.detail_save_path_str.format(rank) for rank in rank_list]

    if not is_continue_run_ut:
        self.write_csv_title()
    if stack_info_json_path:
        self.stack_info = get_json_contents(stack_info_json_path)
    else:
        self.stack_info = None
|
|
45
56
|
|
|
57
|
+
@staticmethod
|
|
58
|
+
def get_path_from_rank(rank, path_list, path_pattern):
    """Choose the CSV path to use for ``rank``.

    When ``path_list`` holds exactly one entry, that entry is returned as is;
    otherwise ``path_pattern`` is filled in with ``rank`` via ``str.format``.
    """
    if len(path_list) == 1:
        return path_list[-1]
    return path_pattern.format(rank)
|
|
60
|
+
|
|
46
61
|
@staticmethod
|
|
47
62
|
def print_pretest_result():
    """Log that run_ut/multi_run_ut has completed."""
    done_message = "Successfully completed run_ut/multi_run_ut."
    logger.info(done_message)
|
|
49
|
-
|
|
64
|
+
|
|
50
65
|
@staticmethod
|
|
51
66
|
def _compare_dropout(bench_output, device_output):
|
|
52
67
|
tensor_num = bench_output.numel()
|
|
@@ -75,7 +90,7 @@ class Comparator:
|
|
|
75
90
|
error_rate = float(error_nums / bench_output.size)
|
|
76
91
|
result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR
|
|
77
92
|
return error_rate, result, ""
|
|
78
|
-
|
|
93
|
+
|
|
79
94
|
@staticmethod
|
|
80
95
|
def _get_absolute_threshold_attribute(api_name, dtype):
|
|
81
96
|
small_value_threshold = apis_threshold.get(api_name).get(dtype).get('small_value')
|
|
@@ -83,35 +98,18 @@ class Comparator:
|
|
|
83
98
|
rtol = apis_threshold.get(api_name).get(dtype).get('rtol')
|
|
84
99
|
return small_value_threshold, small_value_atol, rtol
|
|
85
100
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
if not os.path.exists(self.save_path):
|
|
90
|
-
write_csv(summary_test_rows, self.save_path)
|
|
91
|
-
if not os.path.exists(self.detail_save_path):
|
|
92
|
-
write_csv(DETAIL_TEST_ROWS, self.detail_save_path)
|
|
93
|
-
|
|
94
|
-
def write_summary_csv(self, test_result):
|
|
95
|
-
test_rows = []
|
|
96
|
-
if self.stack_info:
|
|
97
|
-
test_rows[0].append(self.COLUMN_STACK_INFO)
|
|
98
|
-
|
|
99
|
-
name = test_result[0]
|
|
100
|
-
df_row = list(test_result[:INDEX_TEST_RESULT__GROUP])
|
|
101
|
-
if test_result[1] == "SKIP":
|
|
102
|
-
df_row.append(test_result[INDEX_TEST_RESULT__GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE])
|
|
103
|
-
if self.stack_info:
|
|
104
|
-
stack_info = "\n".join(self.stack_info[name])
|
|
105
|
-
df_row.append(stack_info)
|
|
106
|
-
test_rows.append(df_row)
|
|
107
|
-
write_csv(test_rows, self.save_path)
|
|
108
|
-
|
|
109
|
-
def write_detail_csv(self, test_result):
|
|
101
|
+
@staticmethod
|
|
102
|
+
def _get_run_ut_detail(test_result):
|
|
103
|
+
"""get run_ut detail before write to csv, called by online run_ut"""
|
|
110
104
|
test_rows = []
|
|
105
|
+
try:
|
|
106
|
+
subject_prefix = test_result[0]
|
|
107
|
+
fwd_result = test_result[3]
|
|
108
|
+
bwd_result = test_result[4]
|
|
109
|
+
except IndexError as e:
|
|
110
|
+
logger.error("List index out of bounds when writing detail CSV.")
|
|
111
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR, "list index out of bounds") from e
|
|
111
112
|
|
|
112
|
-
subject_prefix = test_result[0]
|
|
113
|
-
fwd_result = test_result[3]
|
|
114
|
-
bwd_result = test_result[4]
|
|
115
113
|
if isinstance(fwd_result, list):
|
|
116
114
|
for i, test_subject in enumerate(fwd_result):
|
|
117
115
|
subject = subject_prefix + ".forward.output." + str(i)
|
|
@@ -124,14 +122,49 @@ class Comparator:
|
|
|
124
122
|
test_subject = ["{:.{}f}".format(item, msCheckerConfig.precision)
|
|
125
123
|
if isinstance(item, float) else item for item in test_subject]
|
|
126
124
|
test_rows.append([subject] + list(test_subject))
|
|
125
|
+
return test_rows
|
|
127
126
|
|
|
128
|
-
|
|
127
|
+
def write_csv_title(self):
    """Write the title rows into every result/detail CSV that does not exist yet.

    Existing files are left untouched. Result and detail paths are walked in
    lockstep over ``self.save_path_list`` and ``self.detail_save_path_list``.
    """
    summary_rows = [[
        self.COLUMN_API_NAME,
        self.COLUMN_FORWARD_SUCCESS,
        self.COLUMN_BACKWARD_SUCCESS,
        "Message",
    ]]
    path_pairs = zip(self.save_path_list, self.detail_save_path_list)
    for summary_path, detail_path in path_pairs:
        # Summary file first, then the matching detail file.
        for title_rows, target in ((summary_rows, summary_path), (DETAIL_TEST_ROWS, detail_path)):
            if os.path.exists(target):
                continue
            write_csv(title_rows, target)
|
|
135
|
+
|
|
136
|
+
def write_summary_csv(self, test_result):
    """Append the summary row of one comparison result to its result CSV.

    The row holds the first ``INDEX_TEST_RESULT_GROUP`` fields of
    ``test_result``, plus the message of the first forward entry when the
    second field is "SKIP", plus the joined stack info when available. The
    target file is chosen from the last field of ``test_result`` (the rank).

    Raises:
        CompareException: when ``test_result`` is too short to index.
    """
    try:
        api_name = test_result[0]
        summary_row = list(test_result[:INDEX_TEST_RESULT_GROUP])
        skipped = test_result[1] == "SKIP"
        if skipped:
            first_entry = test_result[INDEX_TEST_RESULT_GROUP][INDEX_FIRST_GROUP]
            summary_row.append(first_entry[INDEX_MESSAGE])
        if self.stack_info:
            summary_row.append("\n".join(self.stack_info[api_name]))
        target_path = self.get_path_from_rank(test_result[-1], self.save_path_list, self.save_path_str)
    except IndexError as e:
        logger.error("List index out of bounds when writing summary CSV.")
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR, "list index out of bounds") from e
    write_csv([summary_row], target_path)
|
|
152
|
+
|
|
153
|
+
def write_detail_csv(self, test_result):
    """Append the detail rows of one comparison result to its detail CSV.

    Rows come from ``_get_run_ut_detail``; the target file is chosen from the
    last field of ``test_result`` (the rank).
    """
    # Build the rows before resolving the path so a malformed result fails
    # in ``_get_run_ut_detail`` first.
    detail_rows = self._get_run_ut_detail(test_result)
    rank = test_result[-1]
    target_path = self.get_path_from_rank(rank, self.detail_save_path_list, self.detail_save_path_str)
    write_csv(detail_rows, target_path)
|
|
129
159
|
|
|
130
160
|
def record_results(self, args):
    """Persist one comparison result: the summary row first, then the detail rows."""
    for writer in (self.write_summary_csv, self.write_detail_csv):
        writer(args)
|
|
133
163
|
|
|
134
|
-
def compare_output(self, full_api_name, data_info):
|
|
164
|
+
def compare_output(self, full_api_name, data_info, is_online=False):
|
|
165
|
+
"""Get compare result and write to result and detail csv.
|
|
166
|
+
is_online: bool, default False. True: called by online api precision compare, only compare without write to csv.
|
|
167
|
+
"""
|
|
135
168
|
_, api_name, _ = full_api_name.split(Const.SEP)
|
|
136
169
|
bench_output, device_output = data_info.bench_output, data_info.device_output
|
|
137
170
|
bench_grad, device_grad = data_info.bench_grad, data_info.device_grad
|
|
@@ -160,6 +193,9 @@ class Comparator:
|
|
|
160
193
|
fwd_compare_alg_results,
|
|
161
194
|
bwd_compare_alg_results,
|
|
162
195
|
data_info.rank)
|
|
196
|
+
if is_online:
|
|
197
|
+
# get run_ut compare detail
|
|
198
|
+
return self._get_run_ut_detail(result_info)
|
|
163
199
|
self.record_results(result_info)
|
|
164
200
|
return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \
|
|
165
201
|
or bwd_success_status == CompareConst.SPACE
|
|
@@ -261,15 +297,15 @@ class Comparator:
|
|
|
261
297
|
abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
|
|
262
298
|
abs_err = get_abs_err(bench_output, device_output)
|
|
263
299
|
rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps)
|
|
264
|
-
if api_name in
|
|
300
|
+
if api_name in thousandth_standard_api:
|
|
265
301
|
thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD)
|
|
266
302
|
compare_column.rel_err_thousandth = thousand_res
|
|
267
303
|
if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST:
|
|
268
304
|
both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output)
|
|
269
|
-
if api_name in
|
|
305
|
+
if api_name in binary_standard_api:
|
|
270
306
|
err_rate, _, _ = self._compare_bool_tensor(bench_output, device_output)
|
|
271
307
|
compare_column.error_rate = err_rate
|
|
272
|
-
elif api_name in
|
|
308
|
+
elif api_name in absolute_standard_api:
|
|
273
309
|
small_value_threshold, small_value_atol, rtol = self._get_absolute_threshold_attribute(
|
|
274
310
|
api_name, str(dtype))
|
|
275
311
|
rel_err = abs_err / abs_bench_with_eps
|
|
@@ -279,7 +315,7 @@ class Comparator:
|
|
|
279
315
|
dtype, rtol)
|
|
280
316
|
compare_column.rel_err_ratio = check_norm_value(normal_value_mask, rel_err, rtol)
|
|
281
317
|
compare_column.abs_err_ratio = check_small_value(abs_err, small_value_mask, small_value_atol)
|
|
282
|
-
elif api_name in
|
|
318
|
+
elif api_name in ulp_standard_api:
|
|
283
319
|
if bench_output.size == 0:
|
|
284
320
|
compare_column.max_ulp_error = 0
|
|
285
321
|
compare_column.mean_ulp_error = 0
|