mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template (new file; excerpt)
@@ -0,0 +1,2081 @@
import os
import re
import stat
import time
from enum import Enum, auto
from abc import ABC, abstractmethod
import csv
import random

import gc
import sys
from pathlib import Path
import mindspore
from mindspore import ops

from tabulate import tabulate

import logging

import traceback


def error_log_with_exp(self, msg: str, exp: Exception):
    """
    msg: your error message
    exp: the Exception instance to record
    """
    # pass the exception's type, message and traceback to .error() via exc_info
    self.error(msg, exc_info=(type(exp), exp, exp.__traceback__))

# attach the helper to Logger
logging.Logger.error_log_with_exp = error_log_with_exp


# 1. Basic configuration: log level INFO, output to the console by default
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s [%(levelname)s] %(message)s',
                    datefmt='%H:%M:%S')

logger = logging.getLogger()

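# A minimal usage sketch of the patched logger (hypothetical call site; the
# template's real callers pass the checker exceptions defined below):
#
#     try:
#         risky_operation()
#     except ValueError as e:
#         logger.error_log_with_exp("operation failed", e)
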
# ======= Constant classes =======

class CodedException(Exception):
    def __init__(self, code, error_info=''):
        super().__init__()
        self.code = code
        self.error_info = self.err_strs.get(code) + error_info

    def __str__(self):
        return self.error_info


class ApiAccuracyCheckerException(CodedException):
    ParseJsonFailed = 0
    UnsupportType = 1
    WrongValue = 2
    ApiWrong = 3
    err_strs = {
        ParseJsonFailed: "[msprobe] Api Accuracy Checker parse json failed: ",
        UnsupportType: "[msprobe] Api Accuracy Checker get unsupported type: ",
        WrongValue: "[msprobe] Api Accuracy Checker get wrong value: ",
        ApiWrong: "[msprobe] Api Accuracy Checker something wrong with api: ",
    }

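# Usage sketch (hypothetical message): raising and rendering a coded error.
#
#     raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue,
#                                       "cos threshold must be in (0, 1]")
#     # -> "[msprobe] Api Accuracy Checker get wrong value: cos threshold must be in (0, 1]"
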
class FileCheckConst:
    """
    Class for file check const
    """
    READ_ABLE = "read"
    WRITE_ABLE = "write"
    READ_WRITE_ABLE = "read and write"
    DIRECTORY_LENGTH = 4096
    FILE_NAME_LENGTH = 255
    FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
    FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
    PKL_SUFFIX = ".pkl"
    NUMPY_SUFFIX = ".npy"
    JSON_SUFFIX = ".json"
    PT_SUFFIX = ".pt"
    CSV_SUFFIX = ".csv"
    XLSX_SUFFIX = ".xlsx"
    YAML_SUFFIX = ".yaml"
    IR_SUFFIX = ".ir"
    ZIP_SUFFIX = ".zip"
    SHELL_SUFFIX = ".sh"
    MAX_PKL_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
    MAX_NUMPY_SIZE = 10737418240  # 10 * 1024 * 1024 * 1024
    MAX_JSON_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
    MAX_PT_SIZE = 10737418240  # 10 * 1024 * 1024 * 1024
    MAX_CSV_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
    MAX_XLSX_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
    MAX_YAML_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
    MAX_IR_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
    MAX_ZIP_SIZE = 10737418240  # 10 * 1024 * 1024 * 1024
    MAX_FILE_IN_ZIP_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
    COMMOM_FILE_SIZE = 1048576  # 1 * 1024 * 1024
    DIR = "dir"
    FILE = "file"
    DATA_DIR_AUTHORITY = 0o750
    DATA_FILE_AUTHORITY = 0o640
    FILE_SIZE_DICT = {
        PKL_SUFFIX: MAX_PKL_SIZE,
        NUMPY_SUFFIX: MAX_NUMPY_SIZE,
        JSON_SUFFIX: MAX_JSON_SIZE,
        PT_SUFFIX: MAX_PT_SIZE,
        CSV_SUFFIX: MAX_CSV_SIZE,
        XLSX_SUFFIX: MAX_XLSX_SIZE,
        YAML_SUFFIX: MAX_YAML_SIZE,
        IR_SUFFIX: MAX_IR_SIZE,
        ZIP_SUFFIX: MAX_ZIP_SIZE
    }
    CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]'


class Const:
    MAX_DEPTH = 10
    PT_FRAMEWORK = "pytorch"
    MS_FRAMEWORK = "mindspore"
    MT_FRAMEWORK = "mindtorch"
    SEP = "."
    KWARGS = 'kwargs'
    INPUT = 'input'
    OUTPUT = 'output'
    INPUT_ARGS = 'input_args'
    INPUT_KWARGS = 'input_kwargs'
    GRAD_INPUT = 'grad_input'
    GRAD_OUTPUT = 'grad_output'
    BACKWARD = 'backward'
    FORWARD = 'forward'


class CompareConst:
    # compare result data
    PASS = 'pass'
    WARNING = 'Warning'
    ERROR = 'error'
    TRUE = 'TRUE'
    FALSE = 'FALSE'
    SKIP = 'SKIP'

    # compare result column name
    COSINE = "Cosine"
    EUC_DIST = "EucDist"
    MAX_ABS_ERR = "MaxAbsErr"
    MAX_RELATIVE_ERR = "MaxRelativeErr"
    MIN_RELATIVE_ERR = "MinRelativeErr"
    MEAN_RELATIVE_ERR = "MeanRelativeErr"
    NORM_RELATIVE_ERR = "NormRelativeErr"

    # accuracy standards
    COS_THRESHOLD = 0.99
    MAX_ABS_ERR_THRESHOLD = 0.001
    MAX_RELATIVE_ERR_THRESHOLD = 0.001
    COS_MAX_THRESHOLD = 0.9
    MAX_ABS_ERR_MAX_THRESHOLD = 1


class MsCompareConst:
    # api_info field
    MINT = "Mint"
    MINT_FUNCTIONAL = "MintFunctional"
    TENSOR_API = "Tensor"
    FUNCTIONAL_API = "Functional"
    FUSION_API = "FUSION"

    API_NAME_STR_LENGTH = 4
    MAX_RECURSION_DEPTH = 20

    # Mindtorch api_info field
    MINDTORCH_TENSOR = "Tensor"
    MINDTORCH = "Torch"
    MINDTORCH_FUNC = "Functional"
    MINDTORCH_NPU = "NPU"
    MINDTORCH_DIST = "Distributed"

    MT_VALID_API_TYPES = [
        MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR
    ]
    SUPPORTED_FUSION_LIST = ["flash_attention_score"]

    TASK_FIELD = "task"
    STATISTICS_TASK = "statistics"
    FRAMEWORK = "framework"
    TENSOR_TASK = "tensor"
    DUMP_DATA_DIR_FIELD = "dump_data_dir"
    DATA_FIELD = "data"

    # supported api yaml
    SUPPORTED_API_LIST_FILE = "checker_support_api.yaml"
    SUPPORTED_TENSOR_LIST_KEY = "tensor"

    # detail_csv
    DETAIL_CSV_API_NAME = "API Name"
    DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
    DETAIL_CSV_TESTED_DTYPE = "Tested Dtype"
    DETAIL_CSV_SHAPE = "Shape"
    DETAIL_CSV_PASS_STATUS = "Status"
    DETAIL_CSV_MESSAGE = "Message"
    DETAIL_CSV_FILE_NAME = "accuracy_checking_details"

    # result_csv
    RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success"
    RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success"
    RESULT_CSV_FILE_NAME = "accuracy_checking_result"

    EPSILON = 1e-8

    class ProcessStatus:
        SUCCESS = "success"
        API_NOT_FOUND = "api_not_found"
        EXCEPTION_SKIP = "exception_skip"

# ======= mindtorch support ========

import torch as mindtorch
from torch import Tensor as mindtorch_tensor
import torch.nn.functional as mindtorch_func
import torch.distributed as mindtorch_dist

is_valid_pt_mt_env = True

def is_mindtorch():
    mindtorch_check_result = False
    try:
        import torch as test_torch
        from mindspore import Tensor as MindsporeTensor
    except ImportError:
        return mindtorch_check_result
    tensor = test_torch.tensor(0.0)
    if isinstance(tensor, MindsporeTensor):
        mindtorch_check_result = True

    return mindtorch_check_result

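# Usage sketch: distinguishing a MindTorch environment (where `torch` is
# backed by MindSpore tensors) from a genuine PyTorch install:
#
#     if is_mindtorch():
#         ...  # route values through the mindtorch dtype mappings below
#     else:
#         ...  # plain PyTorch path
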
def remove_torch_related_paths():
    removed_paths = []
    if not is_mindtorch():
        return
    try:
        import torch as remove_torch
        torch_file = remove_torch.__file__
    except ImportError:
        return

    torch_dir = os.path.dirname(torch_file)

    torch_dir_path = Path(torch_dir).resolve()
    parent_dir = torch_dir_path.parent

    paths_to_remove = [str(parent_dir)]

    for path in paths_to_remove:
        try:
            path_resolved = str(Path(path).resolve())
        except Exception as error:
            logger.debug(f"Failed to resolve path {path}: {error}")
            continue  # skip unresolvable paths instead of reusing an unbound name

        if path_resolved in sys.path:
            index = sys.path.index(path_resolved)
            removed_paths.append((path_resolved, index))
            sys.path.pop(index)

    return

def clear_torch_from_sys_modules():
    modules_to_remove = []
    for module in sys.modules:
        if module == "torch" or module.startswith("torch."):
            modules_to_remove.append(module)

    for module in modules_to_remove:
        del sys.modules[module]


def set_pt_mt_env_invalid():
    global is_valid_pt_mt_env
    is_valid_pt_mt_env = False


def delete_torch_paths():
    if not is_mindtorch():
        set_pt_mt_env_invalid()

    clear_torch_from_sys_modules()

    for count_delete_env_path in range(MsCompareConst.MAX_RECURSION_DEPTH):
        if not is_mindtorch():
            break

        remove_torch_related_paths()

        clear_torch_from_sys_modules()

        if count_delete_env_path >= MsCompareConst.MAX_RECURSION_DEPTH - 1:
            raise Exception(f"Please check if you have a valid PyTorch and MindTorch environment, and ensure "
                            f"the PYTHONPATH environment variable depth does not exceed "
                            f"{MsCompareConst.MAX_RECURSION_DEPTH}.")


if not is_mindtorch():
    set_pt_mt_env_invalid()

else:
    initial_sys_path = sys.path.copy()
    delete_torch_paths()

    gc.collect()

    import torch

    if is_mindtorch():
        set_pt_mt_env_invalid()

    sys.path = initial_sys_path


if not is_valid_pt_mt_env:
    import torch

# ======= Constant classes =======

import numpy as np
from mindspore._c_expression import typing
from mindspore.common import dtype as mstype


TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
TORCH_BOOL_TYPE = ["torch.bool"]
TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int",
                  "torch.int64", "torch.long"]
TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float",
                    "torch.float64", "torch.double"]
TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128",
                      "torch.cdouble"]
# doubled braces are preserved from the source: in a str.format-rendered
# .template file they produce literal braces in the generated script
RAISE_PRECISION = {{
    "torch.float16": torch.float32,
    "torch.half": torch.float32,
    "torch.bfloat16": torch.float32,
    "torch.float32": torch.float64,
    "torch.float": torch.float64
}}
THOUSANDTH_THRESHOLDING = 0.001
BACKWARD = 'backward'
DIR = "dir"
FILE = "file"
READ_ABLE = "read"
WRITE_ABLE = "write"
READ_WRITE_ABLE = "read and write"
DIRECTORY_LENGTH = 4096
FILE_NAME_LENGTH = 255
SOFT_LINK_ERROR = "soft link detected"
FILE_PERMISSION_ERROR = "file permission error"
INVALID_FILE_ERROR = "invalid file"
ILLEGAL_PATH_ERROR = "illegal file path"
ILLEGAL_PARAM_ERROR = "illegal open mode"
FILE_TOO_LARGE_ERROR = "file too large"
FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
FILE_SIZE_DICT = {{
    ".pkl": 1073741824,  # 1 * 1024 * 1024 * 1024
    ".npy": 10737418240,  # 10 * 1024 * 1024 * 1024
    ".json": 1073741824,  # 1 * 1024 * 1024 * 1024
    ".pt": 10737418240,  # 10 * 1024 * 1024 * 1024
    ".csv": 1073741824,  # 1 * 1024 * 1024 * 1024
    ".xlsx": 1073741824,  # 1 * 1024 * 1024 * 1024
    ".yaml": 1073741824,  # 1 * 1024 * 1024 * 1024
    ".ir": 1073741824  # 1 * 1024 * 1024 * 1024
}}
COMMOM_FILE_SIZE = 1048576  # 1 * 1024 * 1024


INT8 = "Int8"
UINT8 = "UInt8"
INT16 = "Int16"
UINT16 = "UInt16"
INT32 = "Int32"
UINT32 = "UInt32"
INT64 = "Int64"
UINT64 = "UInt64"
FLOAT16 = "Float16"
FLOAT32 = "Float32"
FLOAT64 = "Float64"
BOOL = "Bool"
BFLOAT16 = "BFloat16"
INT4 = "Int4"

dtype_str_to_ms_dtype = {
    INT8: mstype.int8,
    UINT8: mstype.uint8,
    INT16: mstype.int16,
    UINT16: mstype.uint16,
    INT32: mstype.int32,
    UINT32: mstype.uint32,
    INT64: mstype.int64,
    UINT64: mstype.uint64,
    FLOAT16: mstype.float16,
    FLOAT32: mstype.float32,
    FLOAT64: mstype.float64,
    BOOL: mstype.bool_,
    BFLOAT16: mstype.bfloat16,
    INT4: mstype.qint4x2
}
ms_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_ms_dtype.items()}

dtype_str_to_np_dtype = {
    INT8: np.int8,
    UINT8: np.uint8,
    INT16: np.int16,
    UINT16: np.uint16,
    INT32: np.int32,
    UINT32: np.uint32,
    INT64: np.int64,
    UINT64: np.uint64,
    FLOAT16: np.float16,
    FLOAT32: np.float32,
    FLOAT64: np.float64,
    BOOL: np.bool_
}
np_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_np_dtype.items()}

dtype_str_to_torch_dtype = {
    INT8: torch.int8,
    UINT8: torch.uint8,
    INT16: torch.int16,
    INT32: torch.int32,
    INT64: torch.int64,
    FLOAT16: torch.float16,
    FLOAT32: torch.float32,
    FLOAT64: torch.float64,
    BOOL: torch.bool,
    BFLOAT16: torch.bfloat16,
}
torch_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_torch_dtype.items()}


dtype_str_to_mindtorch_dtype = {
    INT8: mindtorch.int8,
    UINT8: mindtorch.uint8,
    INT16: mindtorch.int16,
    INT32: mindtorch.int32,
    INT64: mindtorch.int64,
    FLOAT16: mindtorch.float16,
    FLOAT32: mindtorch.float32,
    FLOAT64: mindtorch.float64,
    BOOL: mindtorch.bool,
    BFLOAT16: mindtorch.bfloat16,
}
mindtorch_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_mindtorch_dtype.items()}

MINDSPORE_TENSOR_TYPE_STR = "mindspore.Tensor"
BOOL_TYPE_STR = "bool"
INT_TYPE_STR = "int"
FLOAT_TYPE_STR = "float"
SLICE_TYPE_STR = "slice"
TUPLE_TYPE_STR = "tuple"
STR_TYPE_STR = "str"
MINDSPORE_DTYPE_TYPE_STR = "mindspore.dtype"
TORCH_DTYPE_TYPE_STR = "torch.dtype"

api_info_type_str_to_type = {
    MINDSPORE_TENSOR_TYPE_STR: mindspore.Tensor,
    BOOL_TYPE_STR: bool,
    INT_TYPE_STR: int,
    FLOAT_TYPE_STR: float,
    SLICE_TYPE_STR: slice,
    STR_TYPE_STR: str,
    MINDSPORE_DTYPE_TYPE_STR: typing.Type,
}
type_to_api_info_type_str = {value: key for key, value in api_info_type_str_to_type.items()}

DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE = np.float64
DEFAULT_CONSTRUCT_NP_INT_DTYPE = np.float64
DEFAULT_CONSTRUCT_NP_UINT_DTYPE = np.float64

float_dtype_str_list = [
    FLOAT16,
    FLOAT32,
    FLOAT64,
    BFLOAT16,
]

int_dtype_str_list = [
    INT8,
    INT16,
    INT32,
    INT64,
    BOOL,
    INT4,
]

uint_dtype_str_list = [
    UINT8,
    UINT16,
    UINT32,
    UINT64,
]

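# Round-trip sketch of the dtype tables (hypothetical values): string names
# from a dumped api_info map to framework dtypes and back.
#
#     ms_dt = dtype_str_to_ms_dtype["Float16"]      # -> mstype.float16
#     assert ms_dtype_to_dtype_str[ms_dt] == "Float16"
#     np_dt = dtype_str_to_np_dtype["Int32"]        # -> np.int32
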
# ======= Comparison classes =======

class CompareResult:
    def __init__(self, compare_value, pass_status, err_msg):
        self.compare_value = compare_value
        self.pass_status = pass_status
        self.err_msg = err_msg


class BaseCompareAlgorithm(ABC):
    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = None
        self.err_msg_mapping = {
            CompareConst.COSINE: {
                CompareConst.PASS: "",
                CompareConst.ERROR: f"cosine similarity is less than threshold: {CompareConst.COS_THRESHOLD} ",
                CompareConst.SKIP: "two inputs are not valid for computing cosine similarity, skip comparing ",
            },
            CompareConst.MAX_ABS_ERR: {
                CompareConst.PASS: "",
                CompareConst.ERROR: "max absolute difference is greater than " \
                                    f"threshold: {CompareConst.MAX_ABS_ERR_THRESHOLD} ",
                CompareConst.SKIP: "two inputs are not valid for computing max absolute difference, skip comparing ",
            },
            CompareConst.MAX_RELATIVE_ERR: {
                CompareConst.PASS: "",
                CompareConst.ERROR: "",
                CompareConst.SKIP: "",
            },
        }

    def __call__(self, bench_compute_element, tested_compute_element):
        '''
        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            compare_result: CompareResult
        '''
        if self.check_validity(bench_compute_element, tested_compute_element):
            compare_value = self.run_compare(bench_compute_element, tested_compute_element)
            pass_status = self.check_pass(compare_value)
        else:
            logger.warning(f"not suitable for computing {self.compare_algorithm_name}, skip this.")
            compare_value = None
            pass_status = CompareConst.SKIP

        err_msg = self.err_msg_mapping.get(self.compare_algorithm_name).get(pass_status)

        compare_result = CompareResult(compare_value, pass_status, err_msg)
        return compare_result

    @staticmethod
    def convert_to_np_float64_ndarray(tensor):
        if isinstance(tensor, mindspore.Tensor):
            ndarray = tensor.astype(mindspore.float64).numpy()
        elif isinstance(tensor, torch.Tensor):
            ndarray = tensor.to(torch.float64, copy=True).numpy()
        else:
            err_msg = "BaseCompareAlgorithm.convert_to_np_float64_ndarray failed: " \
                      "input is not mindspore.Tensor or torch.Tensor"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
            # error_log_with_exp only logs, so raise here to avoid returning an unbound name
            raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType, err_msg)
        return ndarray

    @staticmethod
    def check_two_tensor(bench_compute_element, tested_compute_element):
        bench_parameter = bench_compute_element.get_parameter()
        tested_parameter = tested_compute_element.get_parameter()

        bench_is_tensor = isinstance(bench_parameter, (mindspore.Tensor, torch.Tensor))
        tested_is_tensor = isinstance(tested_parameter, (mindspore.Tensor, torch.Tensor))
        shape_same = bench_compute_element.get_shape() == tested_compute_element.get_shape()
        return bench_is_tensor and tested_is_tensor and shape_same

    @abstractmethod
    def check_validity(self, bench_compute_element, tested_compute_element):
        '''
        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            check_res: boolean
        '''
        raise NotImplementedError

    @abstractmethod
    def run_compare(self, bench_compute_element, tested_compute_element):
        '''
        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            compare_value: float/int
        '''
        raise NotImplementedError

    @abstractmethod
    def check_pass(self, compare_value):
        '''
        Args:
            compare_value: float/int

        Return:
            pass_status: str
        '''
        raise NotImplementedError

class CosineSimilarityCompareAlgorithm(BaseCompareAlgorithm):
    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.COSINE

    def check_validity(self, bench_compute_element, tested_compute_element):
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
        tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())

        bench_norm = np.linalg.norm(bench_ndarray)
        tested_norm = np.linalg.norm(tested_ndarray)
        dot_product = np.dot(bench_ndarray.flatten(), tested_ndarray.flatten())
        cosine_similarity = (MsCompareConst.EPSILON + dot_product) / (MsCompareConst.EPSILON + bench_norm * tested_norm)
        return cosine_similarity

    def check_pass(self, compare_value):
        if compare_value > CompareConst.COS_THRESHOLD:
            return CompareConst.PASS
        else:
            return CompareConst.ERROR


class MaxAbsoluteDiffCompareAlgorithm(BaseCompareAlgorithm):
    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.MAX_ABS_ERR

    def check_validity(self, bench_compute_element, tested_compute_element):
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
        tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())

        max_absolute_diff = np.max(np.abs(bench_ndarray - tested_ndarray))
        return max_absolute_diff

    def check_pass(self, compare_value):
        if compare_value < CompareConst.MAX_ABS_ERR_THRESHOLD:
            return CompareConst.PASS
        else:
            return CompareConst.ERROR


class MaxRelativeDiffCompareAlgorithm(BaseCompareAlgorithm):
    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.MAX_RELATIVE_ERR

    def check_validity(self, bench_compute_element, tested_compute_element):
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
        tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())

        abs_diff = np.abs(bench_ndarray - tested_ndarray)
        bench_ndarray_nonzero = np.abs(bench_ndarray) + (bench_ndarray == 0) * MsCompareConst.EPSILON
        max_relative_diff = np.max(abs_diff / bench_ndarray_nonzero)
        return max_relative_diff

    def check_pass(self, compare_value):
        if compare_value < CompareConst.MAX_RELATIVE_ERR_THRESHOLD:
            return CompareConst.PASS
        else:
            return CompareConst.ERROR


compare_algorithms = {
    CompareConst.COSINE: CosineSimilarityCompareAlgorithm(),
    CompareConst.MAX_ABS_ERR: MaxAbsoluteDiffCompareAlgorithm(),
    CompareConst.MAX_RELATIVE_ERR: MaxRelativeDiffCompareAlgorithm(),
}

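# Usage sketch (hypothetical ComputeElement pair): run every registered
# algorithm on a bench/tested pair and collect one CompareResult per metric.
#
#     results = {name: algo(bench_element, tested_element)
#                for name, algo in compare_algorithms.items()}
#     for name, res in results.items():
#         print(name, res.compare_value, res.pass_status, res.err_msg)
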
class CompareStandard(Enum):
    BINARY_EQUALITY_STANDARD = auto()
    ABSOLUTE_THRESHOLD_STANDARD = auto()
    ULP_ERROR_STANDARD = auto()
    BENCHMARK_STANDARD = auto()
    THOUSANDTH_STANDARD = auto()

# ======== File operation classes ==========

from collections import defaultdict
from functools import wraps


def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_type=None, accepted_value=None):
    '''
    Args:
        dict_instance: dict, dict parsed from input json
        key: str
        key_description: str
        accepted_type: tuple
        accepted_value: Union[tuple, list]

    Return:
        value, the corresponding value of "key" in "dict_instance"

    Exception:
        raise ApiAccuracyCheckerException.ParseJsonFailed error when
        1. dict_instance is not a dict
        2. value is None
        3. value is not accepted type
        4. value is not accepted value
    '''
    if not isinstance(dict_instance, dict):
        error_info = "check_and_get_from_json_dict failed: input is not a dict"
        raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
    value = dict_instance.get(key)
    if value is None:
        error_info = f"check_and_get_from_json_dict failed: {key_description} is missing"
        raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
    elif accepted_type is not None and not isinstance(value, accepted_type):
        error_info = f"check_and_get_from_json_dict failed: {key_description} is not accepted type: {accepted_type}"
        raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
    elif accepted_value is not None and value not in accepted_value:
        error_info = f"check_and_get_from_json_dict failed: {key_description} is not accepted value: {accepted_value}"
        raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
    return value

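# Usage sketch (hypothetical dump fragment): pull a validated field out of a
# parsed dump.json dict.
#
#     task = check_and_get_from_json_dict(
#         dump_dict, MsCompareConst.TASK_FIELD, "task field",
#         accepted_type=(str,),
#         accepted_value=(MsCompareConst.STATISTICS_TASK, MsCompareConst.TENSOR_TASK))
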
def convert_to_tuple(args):
    if isinstance(args, (tuple, list)):
        return tuple(args)
    else:
        input_list = [args]
        return tuple(input_list)


def trim_output_compute_element_list(compute_element_list, forward_or_backward):
    '''
    Args:
        compute_element_list: List[ComputeElement]
        forward_or_backward: str, Union["forward", "backward"]
    '''
    trimmed_list = []
    for compute_element in compute_element_list:
        if compute_element.get_parameter() is None or \
                (forward_or_backward == Const.BACKWARD and compute_element.get_dtype() not in float_dtype_str_list):
            # trim case: 1. parameter is None. 2. backward output has non float parameter
            continue
        trimmed_list.append(compute_element)
    return trimmed_list


# track the recursion depth of utility functions
recursion_depth = defaultdict(int)

def recursion_depth_decorator(func_info, max_depth=Const.MAX_DEPTH):
    """Decorate a function so that an exception naming it is raised when its recursive calls exceed the limit."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            func_id = id(func)
            recursion_depth[func_id] += 1
            # enforce the limit promised by the docstring; without this check
            # the decorator would only count calls and never raise
            if recursion_depth[func_id] > max_depth:
                recursion_depth[func_id] -= 1
                raise Exception(f"call stack of {func_info} exceeds the recursion limit: {max_depth}")

            try:
                result = func(*args, **kwargs)
            finally:
                recursion_depth[func_id] -= 1
            return result

        return wrapper

    return decorator

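# Usage sketch: cap a hypothetical recursive walker at the default depth
# (Const.MAX_DEPTH == 10), as create_directory does below with max_depth=16.
#
#     @recursion_depth_decorator('walk_nested_args')
#     def walk_nested_args(arg):
#         if isinstance(arg, (tuple, list)):
#             for item in arg:
#                 walk_nested_args(item)
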
class FileChecker:
    """
    The class for checking files.

    Attributes:
        file_path: The file or directory path to be verified.
        path_type: file or directory
        ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to check writability or readability
        file_type(str): The correct file type for file
    """

    def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True):
        self.file_path = file_path
        self.path_type = self._check_path_type(path_type)
        self.ability = ability
        self.file_type = file_type
        self.is_script = is_script

    @staticmethod
    def _check_path_type(path_type):
        if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]:
            logger.error(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.')
            raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR)
        return path_type

    def common_check(self):
        """
        Purpose: verify basic file permissions: soft links, path length, existence,
                 read/write permission, file ownership, and special characters
        Note: suffix validity is not a generic check and can be done with a separate interface
        """
        check_path_exists(self.file_path)
        check_link(self.file_path)
        self.file_path = os.path.realpath(self.file_path)
        check_path_length(self.file_path)
        check_path_type(self.file_path, self.path_type)
        self.check_path_ability()
        if self.is_script:
            check_path_owner_consistent(self.file_path)
        check_path_pattern_valid(self.file_path)
        check_common_file_size(self.file_path)
        check_file_suffix(self.file_path, self.file_type)
        if self.path_type == FileCheckConst.FILE:
            check_dirpath_before_read(self.file_path)
        return self.file_path

    def check_path_ability(self):
        if self.ability == FileCheckConst.WRITE_ABLE:
            check_path_writability(self.file_path)
        if self.ability == FileCheckConst.READ_ABLE:
            check_path_readability(self.file_path)
        if self.ability == FileCheckConst.READ_WRITE_ABLE:
            check_path_readability(self.file_path)
            check_path_writability(self.file_path)

class FileOpen:
    """
    The class for opening files safely.

    Attributes:
        file_path: The file or directory path to be opened.
        mode(str): The file open mode
    """
    SUPPORT_READ_MODE = ["r", "rb"]
    SUPPORT_WRITE_MODE = ["w", "wb", "a", "ab"]
    SUPPORT_READ_WRITE_MODE = ["r+", "rb+", "w+", "wb+", "a+", "ab+"]

    def __init__(self, file_path, mode, encoding='utf-8'):
        self.file_path = file_path
        self.mode = mode
        self.encoding = encoding
        self._handle = None

    def __enter__(self):
        self.check_file_path()
        binary_mode = "b"
        if binary_mode not in self.mode:
            self._handle = open(self.file_path, self.mode, encoding=self.encoding)
        else:
            self._handle = open(self.file_path, self.mode)
        return self._handle

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._handle:
            self._handle.close()

    def check_file_path(self):
        support_mode = self.SUPPORT_READ_MODE + self.SUPPORT_WRITE_MODE + self.SUPPORT_READ_WRITE_MODE
        if self.mode not in support_mode:
            logger.error("File open not support %s mode" % self.mode)
            raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR)
        check_link(self.file_path)
        self.file_path = os.path.realpath(self.file_path)
        check_path_length(self.file_path)
        self.check_ability_and_owner()
        check_path_pattern_valid(self.file_path)
        if os.path.exists(self.file_path):
            check_common_file_size(self.file_path)
            check_dirpath_before_read(self.file_path)

    def check_ability_and_owner(self):
        if self.mode in self.SUPPORT_READ_MODE:
            check_path_exists(self.file_path)
            check_path_readability(self.file_path)
            check_path_owner_consistent(self.file_path)
        if self.mode in self.SUPPORT_WRITE_MODE and os.path.exists(self.file_path):
            check_path_writability(self.file_path)
            check_path_owner_consistent(self.file_path)
        if self.mode in self.SUPPORT_READ_WRITE_MODE and os.path.exists(self.file_path):
            check_path_readability(self.file_path)
            check_path_writability(self.file_path)
            check_path_owner_consistent(self.file_path)

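# Usage sketch: FileOpen is a drop-in, validated replacement for open().
#
#     with FileOpen("./accuracy_checking_result.csv", "r") as f:
#         rows = list(csv.reader(f))
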
def check_link(path):
    abs_path = os.path.abspath(path)
    if os.path.islink(abs_path):
        logger.error('The file path {} is a soft link.'.format(path))
        raise FileCheckException(FileCheckException.SOFT_LINK_ERROR)


def check_path_length(path, name_length=None):
    file_max_name_length = name_length if name_length else FileCheckConst.FILE_NAME_LENGTH
    if len(path) > FileCheckConst.DIRECTORY_LENGTH or \
            len(os.path.basename(path)) > file_max_name_length:
        logger.error('The file path length exceeds limit.')
        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)


def check_path_exists(path):
    if not os.path.exists(path):
        logger.error('The file path %s does not exist.' % path)
        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)


def check_path_readability(path):
    if not os.access(path, os.R_OK):
        logger.error('The file path %s is not readable.' % path)
        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)


def check_path_writability(path):
    if not os.access(path, os.W_OK):
        logger.error('The file path %s is not writable.' % path)
        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)


def check_path_executable(path):
    if not os.access(path, os.X_OK):
        logger.error('The file path %s is not executable.' % path)
        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)


def check_other_user_writable(path):
    st = os.stat(path)
    if st.st_mode & 0o002:
        logger.error('The file path %s may be insecure because other users have write permissions. ' % path)
        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)


def check_path_owner_consistent(path):
    file_owner = os.stat(path).st_uid
    if file_owner != os.getuid() and os.getuid() != 0:
        logger.error('The file path %s may be insecure because it does not belong to you.' % path)
        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)


def check_path_pattern_valid(path):
    if not re.match(FileCheckConst.FILE_VALID_PATTERN, path):
        logger.error('The file path %s contains special characters.' % (path))
        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)


def check_file_size(file_path, max_size):
    try:
        file_size = os.path.getsize(file_path)
    except OSError as os_error:
        logger.error(f'Failed to open "{file_path}". {str(os_error)}')
        raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) from os_error
    if file_size >= max_size:
        logger.error(f'The size ({file_size}) of {file_path} exceeds ({max_size}) bytes, tools not support.')
        raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR)


def check_common_file_size(file_path):
    if os.path.isfile(file_path):
        for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items():
            if file_path.endswith(suffix):
                check_file_size(file_path, max_size)
                return
        check_file_size(file_path, FileCheckConst.COMMOM_FILE_SIZE)


def check_file_suffix(file_path, file_suffix):
    if file_suffix:
        if not file_path.endswith(file_suffix):
            logger.error(f"The {file_path} should be a {file_suffix} file!")
            raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)


def check_path_type(file_path, file_type):
    if file_type == FileCheckConst.FILE:
        if not os.path.isfile(file_path):
            logger.error(f"The {file_path} should be a file!")
            raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
    if file_type == FileCheckConst.DIR:
        if not os.path.isdir(file_path):
            logger.error(f"The {file_path} should be a directory!")
            raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)

def make_dir(dir_path):
    check_path_before_create(dir_path)
    dir_path = os.path.realpath(dir_path)
    if os.path.isdir(dir_path):
        return
    try:
        os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True)
    except OSError as ex:
        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR,
                                 f"Failed to create {dir_path}. "
                                 f"Please check the path permission or disk space. {str(ex)}") from ex
    file_check = FileChecker(dir_path, FileCheckConst.DIR)
    file_check.common_check()


@recursion_depth_decorator('msprobe.core.common.file_utils.create_directory', max_depth=16)
def create_directory(dir_path):
    """
    Function Description:
        creating a safe directory with specified permissions
    Parameter:
        dir_path: directory path
    Exception Description:
        when invalid data throw exception
    """
    check_link(dir_path)
    check_path_before_create(dir_path)
    dir_path = os.path.realpath(dir_path)
    parent_dir = os.path.dirname(dir_path)
    if not os.path.isdir(parent_dir):
        create_directory(parent_dir)
    make_dir(dir_path)


def check_path_before_create(path):
    check_link(path)
    if path_len_exceeds_limit(path):
        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, 'The file path length exceeds limit.')

    if not re.match(FileCheckConst.FILE_PATTERN, os.path.realpath(path)):
        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR,
                                 'The file path {} contains special characters.'.format(path))


def check_dirpath_before_read(path):
    path = os.path.realpath(path)
    dirpath = os.path.dirname(path)
    # assumption: warn (rather than raise) when the parent directory looks insecure
    try:
        check_other_user_writable(dirpath)
        check_path_owner_consistent(dirpath)
    except FileCheckException:
        logger.warning(f'The directory {dirpath} of the file may be insecure.')

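# Usage sketch: build a dump output tree; create_directory recurses through
# missing parents (bounded by the decorator above) and make_dir applies the
# 0o750 directory mode.
#
#     create_directory("./msprobe_output/dump_data")
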
def check_file_or_directory_path(path, isdir=False):
    """
    Function Description:
        check whether the path is valid
    Parameter:
        path: the path to check
        isdir: the path is dir or file
    Exception Description:
        when invalid data throw exception
    """
    if isdir:
        path_checker = FileChecker(path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE)
    else:
        path_checker = FileChecker(path, FileCheckConst.FILE, FileCheckConst.READ_ABLE)
    path_checker.common_check()


def change_mode(path, mode):
    if not os.path.exists(path) or os.path.islink(path):
        return
    try:
        os.chmod(path, mode)
    except PermissionError as ex:
        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR,
                                 'Failed to change {} authority. {}'.format(path, str(ex))) from ex


def path_len_exceeds_limit(file_path):
    return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \
        len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH


def load_npy(filepath):
    check_file_or_directory_path(filepath)
    try:
        npy = np.load(filepath, allow_pickle=False)
    except Exception as e:
        logger.error(f"The numpy file failed to load. Please check the path: {filepath}.")
        raise RuntimeError(f"Load numpy file {filepath} failed.") from e
    return npy

def write_csv(data, filepath, mode="a+", malicious_check=False):
|
|
1107
|
+
def csv_value_is_valid(value: str) -> bool:
|
|
1108
|
+
if not isinstance(value, str):
|
|
1109
|
+
return True
|
|
1110
|
+
try:
|
|
1111
|
+
# -1.00 or +1.00 should be considered as digit numbers
|
|
1112
|
+
float(value)
|
|
1113
|
+
except ValueError:
|
|
1114
|
+
# otherwise, they will be considered as formular injections
|
|
1115
|
+
return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
|
|
1116
|
+
return True
|
|
1117
|
+
|
|
1118
|
+
if malicious_check:
|
|
1119
|
+
for row in data:
|
|
1120
|
+
for cell in row:
|
|
1121
|
+
if not csv_value_is_valid(cell):
|
|
1122
|
+
raise RuntimeError(f"Malicious value [{cell}] is not allowed "
|
|
1123
|
+
f"to be written into the csv: {filepath}.")
|
|
1124
|
+
|
|
1125
|
+
check_path_before_create(filepath)
|
|
1126
|
+
file_path = os.path.realpath(filepath)
|
|
1127
|
+
try:
|
|
1128
|
+
with FileOpen(filepath, mode, encoding='utf-8-sig') as f:
|
|
1129
|
+
writer = csv.writer(f)
|
|
1130
|
+
writer.writerows(data)
|
|
1131
|
+
except Exception as e:
|
|
1132
|
+
logger.error(f'Save csv file "{os.path.basename(file_path)}" failed')
|
|
1133
|
+
raise RuntimeError(f"Save csv file {file_path} failed.") from e
|
|
1134
|
+
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
1135
|
+
print(f"file_path:{file_path}")
|
|
1136
|
+
|
|
1137
|
+
|
|
1138
|
+
|
|
1139
|
+
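
# --- Editor's note: illustrative sketch, not part of the original module. ---
# With malicious_check=True, strings that do not parse as numbers are screened
# against FileCheckConst.CSV_BLACK_LIST before being written, a common guard
# against spreadsheet formula injection ('='-style prefixes). The paths and
# values below are made up, and this helper is never called.
def _example_write_csv_guard():
    write_csv([["api", "+1.00", "-3.5"]], "./demo.csv", malicious_check=True)  # numeric-like strings pass
    try:
        write_csv([["=cmd|' /C calc'!A0"]], "./demo.csv", malicious_check=True)
    except RuntimeError as err:
        print(err)  # rejected as a potential formula injection
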
def write_csv_header(csv_path, header_func):
    """Write the CSV header (used on the first write)."""
    header = header_func()  # fetch the header row
    logger.debug(f"Writing CSV header: {header}")
    write_csv([header], csv_path, mode="a+")


def get_result_csv_header():
    """Return the header row of the result CSV file."""
    return [
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
        MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    ]


def get_detail_csv_header():
    """Return the header row of the detail CSV file."""
    detail_csv_header_basic_info = [
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
        MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
        MsCompareConst.DETAIL_CSV_SHAPE,
    ]
    detail_csv_header_compare_result = list(compare_algorithms.keys())
    detail_csv_header_status = [
        MsCompareConst.DETAIL_CSV_PASS_STATUS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    ]
    return detail_csv_header_basic_info + detail_csv_header_compare_result + detail_csv_header_status


def check_csv_header(headers, required_constants, csv_path):
    """Check that the CSV header contains every required constant."""
    missing_constants = [const for const in required_constants if not any(const in header for header in headers)]

    if missing_constants:
        raise MsprobeBaseException(
            MsprobeBaseException.MISSING_HEADER_ERROR,
            f"{csv_path} is missing the required header fields: {missing_constants}"
        )
    return True  # editor's note: returning True lets callers guard on a successful check


def add_time_as_suffix(name):
    return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))

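
# --- Editor's note: illustrative sketch, not part of the original module. ---
# How the header helpers above fit together: write a header once, read it
# back, and validate it. read_csv(..., as_pd=False) is used the same way
# further below; the file name here is made up, and this helper is never
# called.
def _example_write_and_check_header(tmp_dir="./op_result_output"):
    result_csv = os.path.join(tmp_dir, add_time_as_suffix("demo_result"))
    write_csv_header(result_csv, get_result_csv_header)
    headers = read_csv(result_csv, as_pd=False)[0]
    check_csv_header(headers, get_result_csv_header(), result_csv)  # raises if a field is missing
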
# ======== Result persistence manager ========

class DataManager:
    def __init__(self, csv_dir, result_csv_path):
        self.results = {}
        self.results_exception_skip = {}
        self.is_first_write = True  # flag that controls header writing
        self.csv_dir = csv_dir
        self.api_names_set = set()  # API names seen so far
        # if result_csv_path is given, resume the check from the previous run
        if result_csv_path:
            self.resume_from_last_csv(result_csv_path)
            self.initialize_api_names_set(result_csv_path)
        else:
            # otherwise build fresh, timestamped output paths; headers are written on first write
            self.result_out_path = os.path.join(self.csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
            self.detail_out_path = os.path.join(
                self.csv_dir,
                os.path.basename(self.result_out_path).replace("result", "details")
            )

        if self.detail_out_path and os.path.exists(self.detail_out_path):
            check_file_or_directory_path(self.detail_out_path)

        if self.result_out_path and os.path.exists(self.result_out_path):
            check_file_or_directory_path(self.result_out_path)

    def initialize_api_names_set(self, result_csv_path):
        """Read the existing CSV and collect the API names that already appear."""
        csv_data = read_csv(result_csv_path, as_pd=False)

        # header row; an empty file yields empty headers
        headers = csv_data[0] if csv_data else []

        # validate the header with the shared checker
        if check_csv_header(headers, get_result_csv_header(), result_csv_path):

            # locate the "API Name" column; the header line may carry a byte
            # order mark, so match by containment rather than equality
            api_name_index = None
            for i, header in enumerate(headers):
                if MsCompareConst.DETAIL_CSV_API_NAME in header:
                    api_name_index = i
                    break

            if api_name_index is None:
                logger.warning(f"{result_csv_path} No column contains 'API Name'.")
                return

            # collect the API name from every data row, skipping the header
            for row in csv_data[1:]:
                if row and len(row) > api_name_index:
                    api_name = row[api_name_index]
                    if api_name:
                        self.api_names_set.add(api_name)

            logger.debug(f"Initialized API names set from existing CSV: {self.api_names_set}")

    def is_unique_api(self, api_name):
        """Return False if the API name was seen before; otherwise record it and return True."""
        if api_name in self.api_names_set:
            return False
        self.api_names_set.add(api_name)
        return True

    def resume_from_last_csv(self, result_csv_path):
        """Resume the check from the result_csv_path of the previous run."""
        # directory of the previous run
        last_dir = os.path.dirname(result_csv_path)

        # reuse that directory and its output paths for subsequent writes
        self.csv_dir = last_dir
        self.detail_out_path = os.path.join(last_dir, os.path.basename(result_csv_path).replace("result", "details"))
        if self.detail_out_path and os.path.exists(self.detail_out_path):
            check_file_or_directory_path(self.detail_out_path)
        self.result_out_path = result_csv_path
        self.is_first_write = False

    def save_results(self, api_name_str):
        """Write the detail output and the result summary, then clear the cached results."""
        if self.is_first_write:
            logger.info("Writing CSV headers for the first time.")
            write_csv_header(self.detail_out_path, get_detail_csv_header)
            write_csv_header(self.result_out_path, get_result_csv_header)
            self.is_first_write = False  # avoid writing headers twice

        logger.debug("Starting to write detailed output to CSV.")
        self.to_detail_csv(self.detail_out_path)
        logger.debug(f"Detailed output for {api_name_str} written to {self.detail_out_path}.")

        logger.debug("Starting to write result summary to CSV.")
        self.to_result_csv(self.result_out_path)
        logger.debug(f"Result summary for {api_name_str} written to {self.result_out_path}.")

        # clear the records for the next call
        self.clear_results()

    def record(self, output_list):
        if output_list is None:
            return
        for output in output_list:
            api_real_name, forward_or_backward, basic_info, compare_result_dict = output
            key = (api_real_name, forward_or_backward)
            if key not in self.results:
                self.results[key] = []
            self.results[key].append((basic_info, compare_result_dict))
        logger.debug(f"Complete self.results after recording: {self.results}")

    def record_exception_skip(self, api_name, forward_or_backward, err_msg):
        '''
        record exception_skip information into self.results_exception_skip.
        self.results_exception_skip: dict{str: dict{"forward": str/None, "backward": str/None}}
        string in key is api_name, string in value is err_msg
        '''
        if api_name not in self.results_exception_skip:
            self.results_exception_skip[api_name] = {Const.FORWARD: None, Const.BACKWARD: None}
        self.results_exception_skip[api_name][forward_or_backward] = err_msg

    def clear_results(self):
        """Clear the self.results data."""
        logger.debug("Clearing self.results data.")
        self.results.clear()
        self.results_exception_skip.clear()

    def to_detail_csv(self, csv_path):
        logger.debug("Preparing detail CSV headers and rows.")
        detail_csv = []

        detail_csv_header_compare_result = list(compare_algorithms.keys())

        for _, results in self.results.items():
            for res in results:
                basic_info, compare_result_dict = res
                csv_row_basic_info = [
                    basic_info.api_name,
                    basic_info.bench_dtype,
                    basic_info.tested_dtype,
                    basic_info.shape
                ]
                csv_row_compare_result = [
                    compare_result_dict.get(algorithm_name).compare_value
                    for algorithm_name in detail_csv_header_compare_result
                ]
                csv_row_status = [basic_info.status, basic_info.err_msg]
                csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
                detail_csv.append(csv_row)
                logger.debug(f"Detail CSV row added: {csv_row}")

        logger.debug(f"Writing detail CSV to {csv_path}.")
        write_csv(detail_csv, csv_path, mode="a+")
        logger.debug(f"Detail CSV written successfully to {csv_path}.")

    def to_result_csv(self, csv_path):
        '''
        depend on both self.results and self.results_exception_skip
        '''
        logger.debug("Preparing result CSV data.")
        result_csv = []

        result_csv_dict = {}
        for key, results in self.results.items():
            api_real_name, forward_or_backward = key
            pass_status = CompareConst.PASS
            overall_err_msg = ""

            for res in results:
                basic_info, _ = res
                if basic_info.status != CompareConst.PASS:
                    pass_status = CompareConst.ERROR
                    overall_err_msg += basic_info.err_msg

            overall_err_msg = "" if pass_status == CompareConst.PASS else overall_err_msg

            if api_real_name not in result_csv_dict:
                result_csv_dict[api_real_name] = ResultCsvEntry()
            if forward_or_backward == Const.FORWARD:
                result_csv_dict[api_real_name].forward_pass_status = pass_status
                result_csv_dict[api_real_name].forward_err_msg = overall_err_msg
            else:
                result_csv_dict[api_real_name].backward_pass_status = pass_status
                result_csv_dict[api_real_name].backward_err_msg = overall_err_msg

        for api_name, entry in result_csv_dict.items():
            overall_err_msg = "" if (entry.forward_pass_status == CompareConst.PASS and
                                     entry.backward_pass_status == CompareConst.PASS) else \
                entry.forward_err_msg + entry.backward_err_msg
            row = [
                api_name,
                entry.forward_pass_status,
                entry.backward_pass_status,
                overall_err_msg
            ]
            # change the row if this api carries exception_skip information
            if api_name in self.results_exception_skip:
                if self.results_exception_skip[api_name][Const.FORWARD] is not None:
                    row[1] = CompareConst.SKIP
                    row[-1] += self.results_exception_skip[api_name][Const.FORWARD]
                if self.results_exception_skip[api_name][Const.BACKWARD] is not None:
                    row[2] = CompareConst.SKIP
                    row[-1] += self.results_exception_skip[api_name][Const.BACKWARD]
                del self.results_exception_skip[api_name]
            result_csv.append(row)
            logger.debug(f"Result CSV row added: {row}")
        for api_name in self.results_exception_skip:
            current_exception_skip = self.results_exception_skip[api_name]
            forward_status = None
            backward_status = None
            err_msg = ""
            if current_exception_skip[Const.FORWARD] is not None:
                forward_status = CompareConst.SKIP
                err_msg += current_exception_skip[Const.FORWARD]
            if current_exception_skip[Const.BACKWARD] is not None:
                backward_status = CompareConst.SKIP
                err_msg += current_exception_skip[Const.BACKWARD]
            row = [api_name, forward_status, backward_status, err_msg]
            result_csv.append(row)

        write_csv(result_csv, csv_path, mode="a+")
        logger.debug(f"Result CSV written successfully to {csv_path}.")

        # prevent duplicate headers on later writes
        self.is_first_write = False

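
# --- Editor's note: illustrative sketch, not part of the original module. ---
# The two DataManager modes: a fresh run builds timestamped result/details
# paths, while passing a previous result CSV resumes it and skips APIs that
# were already checked. Both paths below are hypothetical, and this helper is
# never called.
def _example_data_manager_modes():
    fresh = DataManager("./op_result_output", None)
    resumed = DataManager("./op_result_output",
                          "./op_result_output/demo_result_20240101000000.csv")
    if resumed.is_unique_api("MintFunctional.relu.0"):
        print("not recorded in the previous run; check it now")
    return fresh, resumed
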
# ======== Global context ========

class GlobalContext:
    def __init__(self):
        self.is_constructed = True
        self.dump_data_dir = ""
        self.framework = Const.MS_FRAMEWORK

    def init(self, is_constructed, dump_data_dir, framework):
        self.is_constructed = is_constructed
        self.dump_data_dir = dump_data_dir
        self.framework = framework

    def get_dump_data_dir(self):
        return self.dump_data_dir

    def get_is_constructed(self):
        return self.is_constructed

    def get_framework(self):
        return self.framework


global_context = GlobalContext()

# ======== Input types ========

def seed_all(seed={random_seed}):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(True)
    mindtorch.manual_seed(seed)
    mindtorch.use_deterministic_algorithms(True)
    mindspore.set_deterministic(True)

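
# Editor's note: the braces in seed_all(seed={random_seed}) above, and in
# {framework}, {api_full_name}, {iter_times} further below, suggest this file
# is a template rendered with str.format-style substitution before it runs;
# that would also explain the doubled braces in print(f"iter: {{i}}:"). A
# hedged sketch of such rendering, with made-up values:
#
#     script = template_text.format(random_seed=1234, framework="mindspore",
#                                   api_full_name="MintFunctional.relu.0",
#                                   iter_times=1)
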
class ApiInputAggregation:
    def __init__(self, inputs, kwargs, gradient_inputs) -> None:
        """
        Args:
            inputs: List[ComputeElement]
            kwargs: dict{str: ComputeElement}
            gradient_inputs: Union[List[ComputeElement], None]
        """
        self.inputs = inputs
        self.kwargs = kwargs
        self.gradient_inputs = gradient_inputs


api_parent_module_mapping = {
    (MsCompareConst.MINT, Const.MS_FRAMEWORK): mindspore.mint,
    (MsCompareConst.MINT, Const.PT_FRAMEWORK): torch,
    (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
    (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional,
    (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): mindspore.Tensor,
    (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor,
    (MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): mindtorch_tensor,
    (MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): torch.Tensor,
    (MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): mindtorch,
    (MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): torch,
    (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): mindtorch_func,
    (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): torch.nn.functional,
    (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): mindtorch_dist,
    (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed,
    (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): mindspore.ops
}


api_parent_module_str_mapping = {
    (MsCompareConst.MINT, Const.MS_FRAMEWORK): "mindspore.mint",
    (MsCompareConst.MINT, Const.PT_FRAMEWORK): "torch",
    (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): "mindspore.mint.nn.functional",
    (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): "torch.nn.functional",
    (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): "mindspore.Tensor",
    (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor",
    (MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): "mindtorch_tensor",
    (MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): "torch.Tensor",
    (MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): "mindtorch",
    (MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): "torch",
    (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): "mindtorch_func",
    (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): "torch.nn.functional",
    (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): "mindtorch_dist",
    (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed",
    (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): "mindspore.ops"
}

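
# Editor's note: both tables above are keyed by (api type, framework), so the
# tested and bench modules for one API family resolve from the same key pair.
# An illustrative lookup (left the MindSpore module under test, right its
# torch benchmark):
#
#     api_parent_module_mapping[(MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK)]
#     # -> mindspore.mint.nn.functional
#     api_parent_module_mapping[(MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK)]
#     # -> torch.nn.functional
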
class ApiRunner:
    def __call__(self, api_input_aggregation, api_name_str, forward_or_backward=Const.FORWARD,
                 api_platform=Const.MS_FRAMEWORK):
        '''
        Args:
            api_input_aggregation: ApiInputAggregation
            api_name_str: str, e.g. "MintFunctional.relu.0"
            forward_or_backward: str, Union["forward", "backward"]
            api_platform: str, Union["mindspore", "torch", "mindtorch"]

        Return:
            outputs: list[ComputeElement]

        Description:
            run mindspore.mint/torch api
        '''

        api_type_str, api_sub_name = self.get_info_from_name(api_name_str, api_platform)
        api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform)

        return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform)

    @staticmethod
    def get_info_from_name(api_name_str, api_platform=Const.MS_FRAMEWORK):
        """
        Args:
            api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"
            api_platform: str, the platform for the API, which can be either "mindspore" or "mindtorch".
                It specifies which framework is being used. Default is "mindspore".
        Return:
            api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Torch", "Functional"]
            api_sub_name: str, e.g. "relu"
        """
        api_name_list = api_name_str.split(Const.SEP)
        if len(api_name_list) != 3:
            err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
        api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
        if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API,
                                MsCompareConst.FUNCTIONAL_API] \
                and api_platform == Const.MS_FRAMEWORK:
            err_msg = "ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))

        if api_type_str not in MsCompareConst.MT_VALID_API_TYPES and api_platform == Const.MT_FRAMEWORK:
            err_msg = "ApiRunner.get_info_from_name failed: not torch, functional or Tensor api"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
        return api_type_str, api_sub_name

    @staticmethod
    def get_api_instance(api_type_str, api_sub_name, api_platform):
        """
        Args:
            api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Functional"]
            api_sub_name: str, e.g. "relu"
            api_platform: str, Union["mindspore", "pytorch"]

        Return:
            api_instance: function object

        Description:
            get mindspore.mint/torch api function
            mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
            mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
        """

        api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
        api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
        full_api_name = api_parent_module_str + Const.SEP + api_sub_name

        if not hasattr(api_parent_module, api_sub_name):
            err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))

        api_instance = getattr(api_parent_module, api_sub_name)
        if not callable(api_instance):
            err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not callable"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))

        return api_instance

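    # Editor's note: an illustrative resolution through the two methods above;
    # "MintFunctional.relu.0" splits into ("MintFunctional", "relu"), and the
    # benchmark instance comes from the torch side of the same mapping:
    #
    #     api_type, sub_name = ApiRunner.get_info_from_name("MintFunctional.relu.0")
    #     tested = ApiRunner.get_api_instance(api_type, sub_name, Const.MS_FRAMEWORK)
    #     bench = ApiRunner.get_api_instance(api_type, sub_name, Const.PT_FRAMEWORK)
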
    @staticmethod
    def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform):
        inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                       for compute_element in api_input_aggregation.inputs)
        kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform)
                  for key, value in api_input_aggregation.kwargs.items()}
        gradient_inputs = api_input_aggregation.gradient_inputs

        if forward_or_backward == Const.FORWARD:
            forward_result = api_instance(*inputs, **kwargs)  # can be a single tensor or a tuple
            forward_result_tuple = convert_to_tuple(forward_result)
            res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple]
            if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK:
                return res_compute_element_list, inputs, kwargs, forward_result_tuple
        else:
            if gradient_inputs is None:
                err_msg = "ApiRunner.run_api failed: run backward api but gradient_inputs is missing"
                logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
            gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                                    for compute_element in gradient_inputs)
            if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK:
                if len(gradient_inputs) == 1:
                    gradient_inputs = gradient_inputs[0]

                def api_with_kwargs(*forward_inputs):
                    return api_instance(*forward_inputs, **kwargs)

                grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs)
                backward_result = grad_func(*inputs, gradient_inputs)  # can be a single tensor or a tuple
                backward_result_tuple = convert_to_tuple(backward_result)
                res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple]
                return res_compute_element_list, gradient_inputs, backward_result_tuple
            else:
                # set requires_grad on float tensors so torch tracks their gradients
                requires_grad_index = []
                for index, tensor in enumerate(inputs):
                    if isinstance(tensor, torch.Tensor) and \
                            torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list:
                        setattr(tensor, "requires_grad", True)
                        requires_grad_index.append(index)
                forward_results = api_instance(*inputs, **kwargs)
                forward_results = convert_to_tuple(forward_results)
                for forward_res, gradient_in in zip(forward_results, gradient_inputs):
                    forward_res.backward(gradient_in)
                backward_result_list = []
                for index in requires_grad_index:
                    backward_result_list.append(getattr(inputs[index], "grad"))
                res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_list]

        return res_compute_element_list


api_runner = ApiRunner()

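
# --- Editor's note: illustrative sketch, not part of the original module. ---
# The MindSpore gradient mechanism run_api relies on:
# ops.GradOperation(get_all=True, sens_param=True) differentiates a function
# with respect to every positional input and takes the upstream gradient (the
# "sensitivity") as the trailing argument. This helper is never called.
def _example_grad_operation():
    def square(x):
        return x * x

    x = mindspore.Tensor(np.array([1.0, 2.0], np.float32))
    sens = mindspore.Tensor(np.array([1.0, 1.0], np.float32))
    grad_fn = ops.GradOperation(get_all=True, sens_param=True)(square)
    return grad_fn(x, sens)  # (d(x*x)/dx * sens,) == ([2., 4.],)
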
# ======== Data structures ========

class ResultCsvEntry:
    def __init__(self) -> None:
        self.forward_pass_status = None
        self.backward_pass_status = None
        self.forward_err_msg = ""
        self.backward_err_msg = ""
        self.overall_err_msg = None


class ProcessResultPacket:
    def __init__(self, process_status, result, err_msg) -> None:
        self.process_status = process_status
        self.result = result
        self.err_msg = err_msg


class MstensorMetaData:
    def __init__(self, dtype_str, npy_path, maximum, minimum, shape) -> None:
        self.dtype_str = dtype_str
        self.npy_path = npy_path
        self.maximum = maximum
        self.minimum = minimum
        self.shape = shape


class DtypeMetaData:
    def __init__(self, dtype_str) -> None:
        self.dtype_str = dtype_str


class ComputeElement:
    def __init__(self, compute_element_info=None, parameter=None):
        self.supported_parameter_type = tuple(type_to_api_info_type_str.keys()) + tuple([torch.Tensor, tuple])
        if parameter is not None:
            self._init_with_parameter(parameter)
        elif isinstance(compute_element_info, (list, dict)):
            self._init_from_compute_element_info(compute_element_info)
        elif compute_element_info is None:
            self._init_from_null_compute_element_info()
        else:
            logger.error_log_with_exp(
                "ComputeElement.__init__ failed: not init with parameter or compute_element info is not (list, dict)",
                ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))

    @staticmethod
    def transfer_to_torch_tensor(ms_tensor):
        '''
        Args:
            ms_tensor: mindspore.Tensor
        Return:
            torch_tensor: torch.Tensor
        '''
        ms_dtype = ms_tensor.dtype
        dtype_str = ms_dtype_to_dtype_str.get(ms_dtype)
        if dtype_str not in dtype_str_to_torch_dtype:
            err_msg = f"ComputeElement.transfer_to_torch_tensor failed: no matching torch dtype for {dtype_str}"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        else:
            torch_dtype = dtype_str_to_torch_dtype.get(dtype_str)

        if dtype_str in int_dtype_str_list:
            middle_dtype = mindspore.int64
        else:
            middle_dtype = mindspore.float64
        np_ndarray = ms_tensor.astype(middle_dtype).numpy()
        torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype)
        return torch_tensor

    @staticmethod
    def transfer_to_mindtorch_tensor(ms_tensor):
        """
        Args:
            ms_tensor: mindspore.Tensor
        Return:
            mindtorch_tensor: mindtorch.Tensor
        """
        ms_dtype = ms_tensor.dtype
        dtype_str = ms_dtype_to_dtype_str.get(ms_dtype)
        if dtype_str not in dtype_str_to_mindtorch_dtype:
            err_msg = f"ComputeElement.transfer_to_mindtorch_tensor failed: no matching mindtorch dtype for {dtype_str}"
            logger.error_log_with_exp(err_msg,
                                      ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        else:
            mindtorch_dtype = dtype_str_to_mindtorch_dtype.get(dtype_str)

        if dtype_str in int_dtype_str_list:
            middle_dtype = mindspore.int64
        else:
            middle_dtype = mindspore.float64

        np_ndarray = ms_tensor.astype(middle_dtype).numpy()
        mindtorch_tensor = mindtorch.from_numpy(np_ndarray).to(ms_dtype)
        return mindtorch_tensor

    @staticmethod
    def transfer_to_mindspore_tensor(torch_tensor):
        '''
        Args:
            torch_tensor: torch.Tensor

        Return:
            ms_tensor: mindspore.Tensor
        '''
        torch_dtype = torch_tensor.dtype
        dtype_str = torch_dtype_to_dtype_str.get(torch_dtype)
        if dtype_str not in dtype_str_to_ms_dtype:
            err_msg = \
                f"ComputeElement.transfer_to_mindspore_tensor failed: no matching mindspore dtype for {dtype_str}"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        else:
            ms_dtype = dtype_str_to_ms_dtype.get(dtype_str)

        if dtype_str in int_dtype_str_list:
            middle_dtype = torch.int64
        else:
            middle_dtype = torch.float64
        np_ndarray = torch_tensor.to(middle_dtype, copy=True).numpy()
        ms_tensor = mindspore.Tensor.from_numpy(np_ndarray).astype(ms_dtype)
        return ms_tensor

    @staticmethod
    def convert_inf_to_real_num(value, dtype_str):
        if value == float("inf"):
            np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
            value = np.finfo(np_dtype).max
        elif value == float("-inf"):
            np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
            value = np.finfo(np_dtype).min
        return value

    def get_parameter(self, get_origin=True, tensor_platform=Const.MS_FRAMEWORK):
        '''
        Args:
            get_origin: boolean
            tensor_platform: str, Union["mindspore", "pytorch"]

        Return:
            parameter: Union[int, float, str, slice, tuple, torch.Tensor, mindspore.Tensor]
        '''
        if self.parameter is None:
            return self.parameter
        if isinstance(self.parameter, tuple):
            return tuple([compute_element.get_parameter(get_origin=get_origin, tensor_platform=tensor_platform)
                          for compute_element in self.parameter])
        elif isinstance(self.parameter, self.supported_parameter_type):
            parameter_tmp = self.parameter
        elif isinstance(self.parameter, DtypeMetaData):
            if tensor_platform == Const.MS_FRAMEWORK:
                parameter_tmp = dtype_str_to_ms_dtype.get(self.parameter.dtype_str)
            elif tensor_platform == Const.PT_FRAMEWORK:
                parameter_tmp = dtype_str_to_torch_dtype.get(self.parameter.dtype_str)
            elif tensor_platform == Const.MT_FRAMEWORK:
                parameter_tmp = dtype_str_to_mindtorch_dtype.get(self.parameter.dtype_str)
        elif isinstance(self.parameter, MstensorMetaData):
            mstensor_meta_data = self.parameter
            ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
            if global_context.get_is_constructed():
                np_dtype = dtype_str_to_np_dtype.get(mstensor_meta_data.dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
                ndarray = self._construct_ndarray(mstensor_meta_data.shape, mstensor_meta_data.maximum,
                                                  mstensor_meta_data.minimum, np_dtype)
            else:
                ndarray = load_npy(mstensor_meta_data.npy_path)
            parameter_tmp = mindspore.Tensor(ndarray, dtype=ms_dtype)
        else:
            err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \
                      "(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))

        # if necessary, do transfer
        if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
            parameter = self.transfer_to_torch_tensor(parameter_tmp)
        elif not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.MT_FRAMEWORK:
            parameter = self.transfer_to_mindtorch_tensor(parameter_tmp)
        elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK:
            parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
        else:
            parameter = parameter_tmp

        return parameter

    def get_shape(self):
        return self.shape

    def get_dtype(self):
        return self.dtype_str

    def _construct_ndarray(self, shape, maximum, minimum, np_dtype):
        shape = tuple(shape)
        np.random.seed({random_seed})
        if np_dtype == np.bool_:
            ndarray = np.random.rand(*shape) > 0.5
        else:
            maximum = self.convert_inf_to_real_num(maximum, np_dtype)
            minimum = self.convert_inf_to_real_num(minimum, np_dtype)
            ndarray = np.random.uniform(minimum, maximum, shape).astype(np_dtype)
        return ndarray

    def _init_from_null_compute_element_info(self):
        self.parameter = None
        self.shape = tuple()
        self.dtype = "None"

    def _init_from_compute_element_info(self, compute_element_info):
        '''
        Args:
            compute_element_info: Union[list, dict]

        Return:
            void

        init member attributes: self.shape, self.dtype_str, self.parameter
        '''
        if isinstance(compute_element_info, list):
            self.shape = tuple()
            self.dtype_str = TUPLE_TYPE_STR
            self.parameter = tuple([ComputeElement(compute_element_info=sub_info)
                                    for sub_info in compute_element_info])
        else:
            type_str = check_and_get_from_json_dict(compute_element_info, "type", "type field in api_info.json",
                                                    accepted_type=str, accepted_value=api_info_type_str_to_type.keys())
            self.shape = tuple()
            self.dtype_str = type_str
            if type_str == MINDSPORE_TENSOR_TYPE_STR:
                self._init_from_mstensor_compute_element_info(compute_element_info)
            else:
                value = check_and_get_from_json_dict(compute_element_info, "value", "value field in api_info.json")
                if type_str == MINDSPORE_DTYPE_TYPE_STR:
                    self.parameter = DtypeMetaData(value)
                elif type_str == SLICE_TYPE_STR:
                    self.parameter = slice(*tuple(value))
                else:  # type_str in ("str", "int", "float", "bool")
                    self.parameter = value

    def _init_from_mstensor_compute_element_info(self, compute_element_info):
        '''
        do not load the real tensor, only record the metadata
        '''
        dtype_str = check_and_get_from_json_dict(compute_element_info, "dtype", "dtype field in api_info.json",
                                                 accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys())
        shape = check_and_get_from_json_dict(compute_element_info, "shape", "shape field in api_info.json",
                                             accepted_type=(list,))
        if global_context.get_is_constructed():
            maximum = check_and_get_from_json_dict(compute_element_info, "Max", "Max field in api_info.json",
                                                   accepted_type=(int, float))
            minimum = check_and_get_from_json_dict(compute_element_info, "Min", "Min field in api_info.json",
                                                   accepted_type=(int, float))
            npy_path = None
        else:
            maximum, minimum = None, None
            data_name = check_and_get_from_json_dict(compute_element_info, "data_name",
                                                     "data_name field in api_info.json", accepted_type=(str,))
            npy_path = os.path.join(global_context.get_dump_data_dir(), data_name)
        mstensor_meta_data = MstensorMetaData(dtype_str, npy_path, maximum, minimum, shape)
        self.parameter = mstensor_meta_data
        self.dtype_str = dtype_str
        self.shape = tuple(shape)

    def _init_with_parameter(self, parameter):
        self.parameter = parameter
        print(f"parameter:{parameter}")
        print(f"self.supported_parameter_type:{self.supported_parameter_type}")
        if isinstance(parameter, dict):
            # assume the dict carries 'type', 'shape', 'dtype' and similar fields
            return self._init_from_compute_element_info(parameter)
        self.shape = tuple()
        if not isinstance(parameter, self.supported_parameter_type):
            err_msg = "ComputeElement._init_with_parameter failed: " \
                      "parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        if isinstance(parameter, mindspore.Tensor):
            self.shape = tuple(parameter.shape)
            self.dtype_str = ms_dtype_to_dtype_str.get(parameter.dtype)
        elif isinstance(parameter, torch.Tensor):
            self.shape = tuple(parameter.shape)
            self.dtype_str = torch_dtype_to_dtype_str.get(parameter.dtype)
        elif isinstance(parameter, typing.Type):
            self.dtype_str = MINDSPORE_DTYPE_TYPE_STR
            self.parameter = DtypeMetaData(ms_dtype_to_dtype_str.get(parameter))
        elif isinstance(parameter, torch.dtype):
            self.dtype_str = TORCH_DTYPE_TYPE_STR
            self.parameter = DtypeMetaData(torch_dtype_to_dtype_str.get(parameter))
        elif isinstance(parameter, tuple):
            self.dtype_str = TUPLE_TYPE_STR
            self.parameter = tuple([ComputeElement(parameter=param) for param in parameter])
        else:
            self.dtype_str = type_to_api_info_type_str.get(type(parameter))
        print(f"self.dtype_str:{self.dtype_str}")


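# --- Editor's note: illustrative sketch, not part of the original module. ---
# The shape of a compute_element_info dict as implied by the fields read in
# ComputeElement above; every literal below is made up, not taken from a real
# api_info.json dump.
#
#     tensor_info = {
#         "type": "mindspore.Tensor",   # MINDSPORE_TENSOR_TYPE_STR in this module
#         "dtype": "Float32",
#         "shape": [2, 3],
#         "Max": 1.0,                   # used when inputs are randomly constructed
#         "Min": -1.0,
#         "data_name": "MintFunctional.relu.0.forward.input.0.npy",  # used with dumped data
#     }
#     element = ComputeElement(compute_element_info=tensor_info)
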
class BasicInfoAndStatus:
    def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
        self.api_name = api_name
        self.bench_dtype = bench_dtype
        self.tested_dtype = tested_dtype
        self.shape = shape
        self.status = status
        self.err_msg = err_msg


# ======== API execution ========

def get_input(propagation):
    args_info_forward = {args_info_forward}
    kwargs_info_forward = {kwargs_info_forward}
    args_info_backward = {args_info_backward}
    forward_inputs = [ComputeElement(compute_element_info=compute_element_info)
                      for compute_element_info in args_info_forward]
    kwargs_compute_element_dict = {
        key_str: ComputeElement(compute_element_info=compute_element_info)
        for key_str, compute_element_info in kwargs_info_forward.items()
    }
    if args_info_backward:
        gradient_inputs = [ComputeElement(compute_element_info=compute_element_info)
                           for compute_element_info in args_info_backward]
    else:
        gradient_inputs = None
    return ApiInputAggregation(
        forward_inputs,
        kwargs_compute_element_dict,
        gradient_inputs
    )


# Run-and-compare helper.
def run_and_compare_helper(api_name_str, api_input_aggregation, forward_or_backward):
    """
    Args:
        api_name_str: str
        api_input_aggregation: ApiInputAggregation
        forward_or_backward: str, Union["forward", "backward"]

    Return:
        output_list: List[tuple(str, str, BasicInfoAndStatus, dict{str: CompareResult})]

    Description:
        get mindspore api output, run torch api and get output.
        compare output.
        record compare result.
    """
    # get output
    if forward_or_backward == Const.FORWARD:
        tested_outputs, inputs, kwargs, forward_result_tuple = api_runner(api_input_aggregation, api_name_str,
                                                                          forward_or_backward,
                                                                          global_context.get_framework())
        print(f"inputs:{inputs}")
        print(f"kwargs:{kwargs}")
        print(f"forward_result_tuple:{forward_result_tuple}")
    elif forward_or_backward == Const.BACKWARD:
        tested_outputs, gradient_inputs, backward_result_tuple = api_runner(api_input_aggregation, api_name_str,
                                                                            forward_or_backward,
                                                                            global_context.get_framework())
        print(f"gradient_inputs:{gradient_inputs}")
        print(f"backward_result_tuple:{backward_result_tuple}")
    else:
        tested_outputs = api_runner(api_input_aggregation, api_name_str,
                                    forward_or_backward, global_context.get_framework())

    bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)

    tested_outputs = trim_output_compute_element_list(tested_outputs, forward_or_backward)
    bench_outputs = trim_output_compute_element_list(bench_outputs, forward_or_backward)

    # compare output
    output_list = []
    for i, (bench_out, tested_out) in enumerate(zip(bench_outputs, tested_outputs)):
        api_name_with_slot = Const.SEP.join([api_name_str, forward_or_backward, Const.OUTPUT, str(i)])
        bench_dtype = bench_out.get_dtype()
        tested_dtype = tested_out.get_dtype()
        shape = bench_out.get_shape()

        compare_result_dict = dict()
        for compare_algorithm_name, compare_algorithm in compare_algorithms.items():
            compare_result = compare_algorithm(bench_out, tested_out)
            compare_result_dict[compare_algorithm_name] = compare_result

        if compare_result_dict.get(CompareConst.COSINE).pass_status == CompareConst.PASS and \
                compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
            status = CompareConst.PASS
            err_msg = ""
        else:
            status = CompareConst.ERROR
            err_msg = (compare_result_dict.get(CompareConst.COSINE).err_msg +
                       compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg)

        # self.pre_forward_hook(api_name_str, None, inputs, kwargs)
        basic_info_status = \
            BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
        output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
    return output_list


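# Editor's note: each entry in the output_list produced above has the shape
# (api_name_str, "forward"/"backward", BasicInfoAndStatus, {algorithm: CompareResult}),
# and an output only counts as PASS when both the cosine-similarity and the
# max-abs-error checks pass. A hypothetical passing entry:
#
#     ("MintFunctional.relu.0", "forward",
#      BasicInfoAndStatus("MintFunctional.relu.0.forward.output.0",
#                         "Float32", "Float32", (2, 3), CompareConst.PASS, ""),
#      {CompareConst.COSINE: cosine_result, CompareConst.MAX_ABS_ERR: max_abs_result})
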
if __name__ == "__main__":
    framework = "{framework}"
    dump_data_dir = "{real_data_path}"
    api_name = "{api_name}"
    api_full_name = "{api_full_name}"
    api_name_str = ".".join(api_full_name.split(".")[:3])
    propagation = "{propagation}"
    data_mode = "{data_mode}"
    seed_all({random_seed})

    data_manager = DataManager("./op_result_output", None)
    create_directory("./op_result_output")

    is_constructed = data_mode == "random_data"
    global_context.init(is_constructed, dump_data_dir, framework)

    for i in range({iter_times}):
        print(f"iter: {{i}}:")
        if propagation == BACKWARD:

            backward_inputs_aggregation = get_input(propagation)

            backward_output_list = run_and_compare_helper(api_name_str, backward_inputs_aggregation,
                                                          Const.BACKWARD)
            process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS,
                                                        result=backward_output_list, err_msg="")

            if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS:
                data_manager.record(process_result_packet.result)
            elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
                data_manager.record_exception_skip(api_name_str, Const.BACKWARD, process_result_packet.err_msg)

            data_manager.save_results(api_name_str)
        else:
            forward_inputs_aggregation = get_input(propagation)

            forward_output_list = run_and_compare_helper(api_name_str, forward_inputs_aggregation,
                                                         Const.FORWARD)
            process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS,
                                                        result=forward_output_list, err_msg="")

            if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS:
                data_manager.record(process_result_packet.result)
            elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
                data_manager.record_exception_skip(api_name_str, Const.FORWARD, process_result_packet.err_msg)

            data_manager.save_results(api_name_str)

    print("Compare finished.")