mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
|
@@ -0,0 +1,365 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import math
|
|
4
|
+
from enum import Enum, auto
|
|
5
|
+
import torch
|
|
6
|
+
try:
|
|
7
|
+
import torch_npu
|
|
8
|
+
except ImportError:
|
|
9
|
+
pass
|
|
10
|
+
from tabulate import tabulate
|
|
11
|
+
|
|
12
|
+
# Dumped "type" values that generate_data materialises as tensors.
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
# dtype strings grouped by family; generate_numerical_tensor and compare_tensor
# test membership in these lists to choose how data is generated and compared.
TORCH_BOOL_TYPE = ["torch.bool"]
TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int",
                  "torch.int64", "torch.long"]
TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float",
                    "torch.float64", "torch.double"]
# NOTE(review): not referenced by the fixed code of this template as far as visible.
TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128", "torch.cdouble"]
# Maps a dtype string to a higher-precision torch dtype.
# NOTE(review): not referenced by the fixed code of this template; presumably
# consumed by the generated benchmark-input code -- confirm against the generator.
# The doubled curly brackets below are str.format escaping (this file is a template).
RAISE_PRECISION = {{
    "torch.float16": torch.float32,
    "torch.half": torch.float32,
    "torch.bfloat16": torch.float32,
    "torch.float32": torch.float64,
    "torch.float": torch.float64
}}
# Relative-error threshold used by the thousandth standard in compare_tensor.
THOUSANDTH_THRESHOLDING = 0.001
# propagation value that selects the backward (gradient) comparison.
BACKWARD = 'backward'
|
|
28
|
+
|
|
29
|
+
class CompareStandard(Enum):
    """Comparison standards understood by compare_tensor.

    The main block sets the module-level compare_standard variable, presumably
    to one of these members. Bool and integer tensors always use binary
    equality regardless of the selection.
    """

    # Exact element-wise equality; reports the share of differing elements.
    BINARY_EQUALITY_STANDARD = auto()
    # Relative error on normal values, absolute error on small values.
    ABSOLUTE_THRESHOLD_STANDARD = auto()
    # Error measured in units in the last place (ULP).
    ULP_ERROR_STANDARD = auto()
    # Statistical metrics: small-value error ratio, max/mean relative error,
    # RMSE and error balance (the fallback branch of compare_tensor).
    BENCHMARK_STANDARD = auto()
    # Share of elements whose relative error is below THOUSANDTH_THRESHOLDING.
    THOUSANDTH_STANDARD = auto()
|
|
35
|
+
|
|
36
|
+
def load_pt(pt_path, to_cpu=False):
    """Load an object previously written with torch.save.

    :param pt_path: path of the file; resolved with os.path.realpath first.
    :param to_cpu: when True, storages are mapped onto the CPU device.
    :return: the deserialised object.
    :raises RuntimeError: when torch.load fails for any reason; the original
        exception is chained as the cause.
    """
    pt_path = os.path.realpath(pt_path)
    load_kwargs = dict(map_location=torch.device("cpu")) if to_cpu else dict()
    # NOTE(review): torch.load unpickles arbitrary objects; only load dump
    # files that come from a trusted source.
    try:
        loaded = torch.load(pt_path, **load_kwargs)
    except Exception as err:
        raise RuntimeError(f"load pt file {{pt_path}} failed") from err
    return loaded
|
|
46
|
+
|
|
47
|
+
def get_device():
    """Return the accelerator device the replicated API runs on.

    CUDA is preferred when available; otherwise an Ascend NPU is used when the
    optional torch_npu extension was imported successfully and reports an
    available device.

    :return: torch.device("cuda") or torch.device("npu").
    :raises Exception: if neither a CUDA device nor an NPU is available.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch_npu is imported under "except ImportError: pass" at module top, so
    # the name may be unbound; referencing it directly raised NameError on
    # machines without the extension instead of the intended error below.
    npu_module = globals().get("torch_npu")
    if npu_module is not None and npu_module.npu.is_available():
        return torch.device("npu")
    raise Exception("Error: This device is not NPU or GPU!")
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def generate_bool_tensor(low, high, shape):
    """Build a random boolean tensor of the given shape.

    Integers are drawn uniformly from the closed range int(low)..int(high);
    an element is True when its drawn value is strictly positive.
    """
    lower_bound = int(low)
    upper_bound = int(high)
    samples = torch.randint(lower_bound, upper_bound + 1, shape)
    return samples > 0
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def generate_numerical_tensor(low, high, shape, data_dtype):
    """Generate a random numeric tensor covering the recorded value range.

    Floating dtypes are sampled as rand * (high - low) + low; integer dtypes are
    sampled from the closed range int(low)..int(high). For non-empty tensors
    the first element is then set to low and the last one to high so the
    recorded extremes appear in the data.

    :param low: recorded minimum value.
    :param high: recorded maximum value.
    :param shape: target shape.
    :param data_dtype: dtype string such as "torch.float16"; must be listed in
        TORCH_FLOAT_TYPE or TORCH_INT_TYPE.
    :return: the generated tensor.
    :raises NotImplementedError: for any other dtype string.
    """
    if data_dtype in TORCH_FLOAT_TYPE:
        # Resolve "torch.<name>" with getattr rather than eval(); every
        # whitelisted string names an attribute of the torch module.
        torch_dtype = getattr(torch, data_dtype.split(".")[-1])
        scale = high - low
        rand01 = torch.rand(shape, dtype=torch_dtype)
        tensor = rand01 * scale + low
    elif data_dtype in TORCH_INT_TYPE:
        torch_dtype = getattr(torch, data_dtype.split(".")[-1])
        low, high = int(low), int(high)
        tensor = torch.randint(low, high + 1, shape, dtype=torch_dtype)
    else:
        raise NotImplementedError(f"{{data_dtype}} is not supported!")
    if torch.numel(tensor) == 0:
        return tensor
    # Pin the recorded extremes at the first and last position.
    flat = tensor.reshape(-1)
    flat[0] = low
    flat[-1] = high
    return flat.reshape(shape)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def generate_random_tensor(info):
    """Create a random tensor matching the recorded metadata in info.

    Reads the "Min", "Max", "dtype" and "shape" entries. Boolean tensors are
    produced by generate_bool_tensor, every other dtype by
    generate_numerical_tensor.
    """
    low = info.get('Min')
    high = info.get('Max')
    data_dtype = info.get('dtype')
    shape = tuple(info.get('shape'))
    if data_dtype != "torch.bool":
        return generate_numerical_tensor(low, high, shape, data_dtype)
    return generate_bool_tensor(low, high, shape)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def generate_real_tensor(data_path):
    """Load a dumped tensor from data_path, mapped onto the CPU, via load_pt."""
    return load_pt(os.path.realpath(data_path), to_cpu=True)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def generate_data(info):
    """Build one API argument from its dumped description.

    Entries whose "type" is listed in TENSOR_DATA_LIST become tensors: loaded
    from "data_name" when it is set, otherwise generated randomly from the
    recorded metadata. Any other entry yields its stored "value" unchanged.
    When "requires_grad" is True and the result is a tensor, gradient tracking
    is enabled on it.
    """
    data_type = info.get("type")
    data_path = info.get("data_name")
    data_grad = info.get("requires_grad")
    if data_type in TENSOR_DATA_LIST:
        data = generate_real_tensor(data_path) if data_path else generate_random_tensor(info)
    else:
        data = info.get("value")
    # Only tensors have requires_grad_; a plain value flagged with
    # requires_grad used to raise AttributeError here.
    if data_grad is True and isinstance(data, torch.Tensor):
        data.requires_grad_(True)
    return data
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def get_input(propagation):
    """Build the argument sets for the device run and the benchmark run.

    The assignment statements and the list/dict contents in this function are
    placeholders filled in when the template is rendered.
    NOTE(review): the statement placeholders sit at column 0, presumably
    because the generator emits already-indented lines -- confirm.

    :param propagation: BACKWARD to also return the gradient inputs.
    :return: (args_device, kwargs_device, args_bench, kwargs_bench), extended
        with (args_device_backward, args_bench_backward) when propagation is
        BACKWARD.
    """
{args_element_assignment}
    args_device = [{args_list_generator_device}]
    args_bench = [{args_list_generator_bench}]
{kwargs_value_assignment}
    kwargs_device = {{{kwargs_dict_generator_device}}}
    kwargs_bench = {{{kwargs_dict_generator_bench}}}
{args_element_assignment_backward}
    args_device_backward = [{args_list_generator_device_backward}]
    args_bench_backward = [{args_list_generator_bench_backward}]
    if propagation == BACKWARD:
        return args_device, kwargs_device, args_bench, kwargs_bench, args_device_backward, args_bench_backward
    return args_device, kwargs_device, args_bench, kwargs_bench
|
|
129
|
+
|
|
130
|
+
def exec_api(args, kwargs, args_grad_input, propagation):
    """Run the API under test once.

    The call target is filled in through the api_type and api_name template
    placeholders.

    :param args: positional arguments for the API.
    :param kwargs: keyword arguments for the API.
    :param args_grad_input: passed as grad_outputs to torch.autograd.grad; only
        used when propagation is BACKWARD.
    :param propagation: BACKWARD to return input gradients, anything else to
        return the forward output.
    :return: the forward output, or the result of torch.autograd.grad with
        respect to every tensor argument that requires grad (positional ones
        first, then keyword ones).
    """
    output = {api_type}.{api_name}(*args, **kwargs)
    if propagation == BACKWARD:
        # Collect the differentiable inputs in a stable order: args, then kwargs.
        args_input_tensor = [tensor for tensor in args if isinstance(tensor, torch.Tensor) and tensor.requires_grad]
        args_input_tensor.extend(
            [value for value in kwargs.values() if isinstance(value, torch.Tensor) and value.requires_grad])
        output_backward = torch.autograd.grad(outputs=output, inputs=args_input_tensor, grad_outputs=args_grad_input)
        return output_backward
    return output
|
|
139
|
+
|
|
140
|
+
def compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol):
    """Return the failing share among positions flagged by inf_nan_mask.

    Both outputs are clamped to the finite range of the device dtype, so
    matching infinities become equal finite values. A position passes when the
    clamped relative error is at most rtol, or when both values are NaN.

    :param inf_nan_mask: bool tensor marking positions where either output is
        not finite.
    :param out_device: device output tensor.
    :param out_bench: benchmark output tensor; cast to the device dtype here.
    :param abs_bench_with_eps: abs(benchmark) + eps, the relative-error denominator.
    :param rtol: relative tolerance.
    :return: integer 0 when the mask is empty, otherwise a 0-dim tensor holding
        failing / flagged.
    """
    out_bench = out_bench.to(out_device.dtype)
    # Use a descriptive local instead of shadowing the min/max builtins.
    dtype_info = torch.finfo(out_device.dtype)
    bench_clip = torch.clamp(out_bench, min=dtype_info.min, max=dtype_info.max)
    device_clip = torch.clamp(out_device, min=dtype_info.min, max=dtype_info.max)
    clipped_re = torch.abs(device_clip - bench_clip) / abs_bench_with_eps
    pass_mask = torch.less_equal(clipped_re, rtol)
    both_nan_mask = torch.logical_and(torch.isnan(out_device), torch.isnan(bench_clip))
    pass_mask = torch.logical_or(pass_mask, both_nan_mask)
    not_pass_mask = torch.logical_and(torch.logical_not(pass_mask), inf_nan_mask)
    flagged_count = torch.sum(inf_nan_mask)
    if flagged_count == 0:
        return 0
    return torch.sum(not_pass_mask) / flagged_count
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def compute_rmse(abs_err, normal_value_mask):
    """Root-mean-square of abs_err over positions where normal_value_mask is True.

    Returns the integer 0 when the mask selects nothing, otherwise a 0-dim tensor.
    """
    selected_count = torch.sum(normal_value_mask)
    if selected_count == 0:
        return 0
    squared = torch.square(torch.where(normal_value_mask, abs_err, 0))
    return torch.sqrt(torch.sum(squared) / selected_count)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def compute_error_balance(out_device, out_bench):
    """Measure how one-sided the errors of out_device are relative to out_bench.

    :return: abs(count(device > bench) - count(device < bench)) / numel, after
        casting the benchmark to the device dtype.
    :raises ZeroDivisionError: when out_bench has no elements.
    """
    element_count = torch.numel(out_bench)
    # Validate before doing any element-wise work.
    if element_count == 0:
        raise ZeroDivisionError(f"ERROR: please check torch.numel out_bench, its value is {{torch.numel(out_bench)}}")
    # Cast and subtract once instead of once per comparison.
    diff = out_device - out_bench.to(out_device.dtype)
    larger_count = torch.sum(torch.greater(diff, 0))
    smaller_count = torch.sum(torch.less(diff, 0))
    return abs(larger_count - smaller_count) / element_count
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def compare_tensor(out_device, out_bench, api_name):
    """Compare one device output tensor with its benchmark and print a metrics table.

    Bool and integer tensors are always checked for exact equality. Floating
    tensors are compared according to the module-level ``compare_standard``
    (assigned in the script's main block): binary equality, absolute
    threshold, ULP error, thousandth ratio, or, for any other value, the
    benchmark statistics (small-value error proportion, max/mean relative
    error, RMSE, error balance).

    :param out_device: tensor computed on the accelerator.
    :param out_bench: benchmark tensor; must have the same shape as out_device.
    :param api_name: name of the API under test; not used by this function.
    :return: None. Shape mismatches, empty tensors and incompatible dtype
        families print a message and return early; otherwise the table is
        printed with tabulate.
    """
    if out_device.shape != out_bench.shape:
        print("ERROR: shape of out_device and out_bench is not equal!")
        return None
    if torch.numel(out_bench) == 0:
        print("Both out_device and out_bench have zero elements.")
        return None
    dtype_device = out_device.dtype
    dtype_bench = out_bench.dtype
    headers = ["Metric", "Value"]
    table = [
        ["Shape", out_bench.shape],
        ["Dtype of out_device", out_device.dtype],
        ["Dtype of out_bench", out_bench.dtype]
    ]
    if str(dtype_device) in TORCH_FLOAT_TYPE and str(dtype_bench) in TORCH_FLOAT_TYPE \
            or str(dtype_device) in TORCH_INT_TYPE and str(dtype_bench) in TORCH_INT_TYPE \
            or str(dtype_device) in TORCH_BOOL_TYPE and str(dtype_bench) in TORCH_BOOL_TYPE:
        # Keep both operands on the CPU: helper tensors below are created with
        # torch.zeros / torch.ones on the default device, and only out_device
        # used to be moved.
        out_device = out_device.to(torch.device("cpu"))
        out_bench = out_bench.to(torch.device("cpu"))
        if str(dtype_device) in TORCH_BOOL_TYPE or str(dtype_device) in TORCH_INT_TYPE \
                or compare_standard == CompareStandard.BINARY_EQUALITY_STANDARD:
            error_number = torch.sum(out_device != out_bench).item()
            # numel is non-zero: empty tensors returned early above.
            error_rate = error_number / torch.numel(out_bench)
            table.append(["Compare Standard", "Binary Equality Standard"])
            table.append(["Error Rate", error_rate])
        else:
            abs_err = torch.abs(out_device - out_bench)
            abs_bench = torch.abs(out_bench)
            # Machine epsilon of the benchmark dtype: 2 ** -23 for float32 and
            # 2 ** -52 for float64, as before. Previously eps was only set for
            # those two dtypes, so any other float benchmark raised NameError.
            eps = torch.finfo(dtype_bench).eps
            abs_bench_with_eps = abs_bench + eps
            rel_err = torch.abs(abs_err / abs_bench_with_eps)
            device_finite_mask = torch.isfinite(out_device)
            bench_finite_mask = torch.isfinite(out_bench.to(dtype_device))
            both_finite_mask = torch.logical_and(device_finite_mask, bench_finite_mask)
            inf_nan_mask = torch.logical_not(both_finite_mask)
            if compare_standard == CompareStandard.ABSOLUTE_THRESHOLD_STANDARD:
                if dtype_device == torch.float16:
                    rtol, small_value, small_value_atol = 1.0e-3, 1.0e-3, 1.0e-5
                elif dtype_device == torch.bfloat16:
                    rtol, small_value, small_value_atol = 4.0e-3, 1.0e-3, 1.0e-5
                else:
                    rtol, small_value, small_value_atol = 1.0e-6, 1.0e-6, 1.0e-9
                small_value_mask = torch.less_equal(abs_bench, small_value)
                small_value_mask = torch.logical_and(small_value_mask, both_finite_mask)
                normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask))
                inf_nan_proportion = compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol)
                rel_err_mask = torch.greater(rel_err, rtol)
                rel_err_mask = torch.logical_and(rel_err_mask, normal_value_mask)
                if torch.sum(normal_value_mask) == 0:
                    rel_err_proportion = 0
                else:
                    rel_err_proportion = torch.sum(rel_err_mask) / torch.sum(normal_value_mask)
                abs_err_mask = torch.greater(abs_err, small_value_atol)
                abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask)
                if torch.sum(small_value_mask) == 0:
                    abs_err_proportion = 0
                else:
                    abs_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask)
                table.append(["Compare Standard", "Absolute Threshold Standard"])
                # The inf/nan ratio used to be computed and then discarded; report it.
                table.append(["Inf/Nan Error Ratio", inf_nan_proportion])
                table.append(["Relative Error Ratio", rel_err_proportion])
                table.append(["Absolute Error Ratio", abs_err_proportion])
            elif compare_standard == CompareStandard.ULP_ERROR_STANDARD:
                # min_eb: minimum normal exponent; exponent_num: fraction-bit count
                # of the device dtype.
                if dtype_device == torch.float16:
                    min_eb, exponent_num = -14, 10
                elif dtype_device == torch.bfloat16:
                    min_eb, exponent_num = -126, 7
                else:
                    min_eb, exponent_num = -126, 23
                eb = torch.where(abs_bench == 0, torch.zeros(out_bench.shape), torch.floor(torch.log2(abs_bench)))
                eb = torch.maximum(eb, min_eb * torch.ones(out_bench.shape))
                if dtype_device == torch.float32:
                    ulp_err = (out_device.to(torch.float64) - out_bench).to(torch.float64) * torch.exp2(-eb + exponent_num).to(torch.float64)
                else:
                    ulp_err = (out_device.to(torch.float32) - out_bench).to(torch.float32) * torch.exp2(-eb + exponent_num).to(torch.float32)
                ulp_err = torch.abs(ulp_err)
                max_ulp_err = torch.max(ulp_err)
                mean_ulp_err = torch.mean(ulp_err)
                # numel is non-zero: empty tensors returned early above.
                if dtype_device == torch.float32:
                    ulp_err_proportion = torch.sum(ulp_err > 32) / torch.numel(out_bench)
                else:
                    ulp_err_proportion = torch.sum(ulp_err > 1) / torch.numel(out_bench)
                table.append(["Compare Standard", "ULP error Standard"])
                table.append(["Maximum ULP Error", max_ulp_err])
                table.append(["Mean ULP Error", mean_ulp_err])
                table.append(["ULP Error Proportion", ulp_err_proportion])
            elif compare_standard == CompareStandard.THOUSANDTH_STANDARD:
                # Share of elements whose relative error is below the threshold.
                thousand_res = torch.divide(torch.sum(rel_err < THOUSANDTH_THRESHOLDING), torch.numel(rel_err))
                table.append(["Compare Standard", "Thousandth Standard"])
                table.append(["Thousandth ratio", thousand_res])
            else:
                if dtype_device == torch.float16:
                    small_value, small_value_atol = 1.0e-3, 1.0e-5
                elif dtype_device == torch.bfloat16:
                    small_value, small_value_atol = 1.0e-3, 1.0e-5
                else:
                    small_value, small_value_atol = 1.0e-6, 1.0e-9
                small_value_mask = torch.less_equal(abs_bench, small_value)
                small_value_mask = torch.logical_and(small_value_mask, both_finite_mask)
                normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask))
                abs_err_mask = torch.greater(abs_err, small_value_atol)
                abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask)
                if torch.sum(small_value_mask) == 0:
                    small_value_err_proportion = 0
                else:
                    small_value_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask)
                # Mark non-normal positions with -1 so max/clamp ignore them.
                rel_err = torch.where(normal_value_mask, rel_err, -1 * torch.ones(out_device.shape))
                if torch.max(rel_err) >= 0:
                    max_rel_err = torch.max(rel_err)
                else:
                    max_rel_err = 0
                if torch.sum(normal_value_mask) == 0:
                    mean_rel_err = 0
                else:
                    mean_rel_err = torch.sum(torch.clamp(rel_err, min=0)) / torch.sum(normal_value_mask)
                rmse = compute_rmse(abs_err, normal_value_mask)
                error_balance = compute_error_balance(out_device, out_bench)
                table.append(["Compare Standard", "Benchmark Standard"])
                table.append(["Small Value Error Proportion", small_value_err_proportion])
                table.append(["Maximum Relative Error", max_rel_err])
                table.append(["Mean Relative Error", mean_rel_err])
                table.append(["Root Mean Squared Error", rmse])
                table.append(["Error Balance", error_balance])
    else:
        print(f"ERROR: out_device dtype is {{dtype_device}}, out_bench dtype is {{dtype_bench}}, not comparable.")
        return None
    print(tabulate(table, headers, tablefmt='grid'))
    return None
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def compare_element(out_device, out_bench, api_name):
    """Compare a single (non-sequence) output element and print the verdict.

    Tensors are handed to compare_tensor; bool/int/float/str values are
    compared with ==. A type mismatch, or any other type, only prints an error.
    """
    if type(out_device) is not type(out_bench):
        print("ERROR: out_device and out_bench is not the same type!")
        return None
    if isinstance(out_bench, torch.Tensor):
        compare_tensor(out_device, out_bench, api_name)
        return None
    if not isinstance(out_bench, (bool, int, float, str)):
        print(f"ERROR: comparison of type {{type(out_bench)}} is not supported.")
        return None
    if out_device == out_bench:
        verdict = "PASS: out_device and out_bench equals."
    else:
        verdict = "ERROR: out_device and out_bench is not equal!"
    print(verdict)
    return None
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def compare(out_device, out_bench, api_name):
    """Print a comparison report for an API's device output against its benchmark.

    List/tuple outputs are compared element by element after checking both
    sequences have the same length; any other output is compared as a single
    element.
    """
    print("Compare result:")
    if type(out_device) is not type(out_bench):
        print("ERROR: out_device and out_bench is not the same type!")
        return None
    if not isinstance(out_bench, (list, tuple)):
        compare_element(out_device, out_bench, api_name)
        return None
    if len(out_device) != len(out_bench):
        print("ERROR: len of out_device and out_bench is different!")
        return None
    for index, (device_item, bench_item) in enumerate(zip(out_device, out_bench)):
        print(f"index {{index}}:")
        compare_element(device_item, bench_item, api_name)
    return None
|
|
346
|
+
|
|
347
|
+
if __name__ == "__main__":
    # Script entry point: for a fixed number of iterations, run the API on the
    # device inputs and on the benchmark inputs, then print their comparison.
    # NOTE(review): device is not used by the fixed code below; presumably the
    # generated input code in get_input refers to it -- confirm.
    device = get_device()
    api_name = "{api_name}"
    propagation = "{propagation}"
    # Read as a module-level global by compare_tensor.
    compare_standard = {compare_standard}
    # Seed torch so the randomly generated inputs are reproducible.
    torch.manual_seed({random_seed})
    for i in range({iter_times}):
        print(f"iter: {{i}}:")
        if propagation == BACKWARD:
            args_device, kwargs_device, args_bench, kwargs_bench, args_device_backward, args_bench_backward = get_input(propagation)
            output_device = exec_api(args_device, kwargs_device, args_device_backward, propagation)
            output_bench = exec_api(args_bench, kwargs_bench, args_bench_backward, propagation)
            compare(output_device, output_bench, api_name)
        else:
            args_device, kwargs_device, args_bench, kwargs_bench = get_input(propagation)
            output_device = exec_api(args_device, kwargs_device, None, propagation)
            output_bench = exec_api(args_bench, kwargs_bench, None, propagation)
            compare(output_device, output_bench, api_name)
    print("Compare finished.")
|
|
@@ -139,7 +139,12 @@ def gen_random_tensor(info, convert_type):
|
|
|
139
139
|
high_info = [high, high_origin]
|
|
140
140
|
data_dtype = info.get('dtype')
|
|
141
141
|
shape = tuple(info.get('shape'))
|
|
142
|
-
if
|
|
142
|
+
if 0 in shape:
|
|
143
|
+
low, low_origin = 0, 0
|
|
144
|
+
high, high_origin = 0, 0
|
|
145
|
+
low_info = [low, low_origin]
|
|
146
|
+
high_info = [high, high_origin]
|
|
147
|
+
elif not isinstance(low, (int, float)) or not isinstance(high, (int, float)):
|
|
143
148
|
error_info = f'Data info Min: {low} , Max: {high}, info type must be int or float.'
|
|
144
149
|
raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
|
|
145
150
|
if data_dtype == "torch.bool":
|
|
@@ -33,9 +33,10 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
|
|
|
33
33
|
from msprobe.pytorch.common import parse_json_info_forward_backward
|
|
34
34
|
from msprobe.pytorch.common.log import logger
|
|
35
35
|
from msprobe.core.common.file_utils import FileChecker, check_file_suffix, check_link, FileOpen, \
|
|
36
|
-
|
|
36
|
+
create_directory, load_json, save_json
|
|
37
37
|
from msprobe.core.common.file_utils import remove_path
|
|
38
|
-
from msprobe.core.common.const import FileCheckConst
|
|
38
|
+
from msprobe.core.common.const import FileCheckConst, Const
|
|
39
|
+
from msprobe.core.common.utils import CompareException
|
|
39
40
|
|
|
40
41
|
|
|
41
42
|
def split_json_file(input_file, num_splits, filter_api):
|
|
@@ -47,9 +48,11 @@ def split_json_file(input_file, num_splits, filter_api):
|
|
|
47
48
|
for data_name in list(backward_data.keys()):
|
|
48
49
|
backward_data[f"{data_name}.backward"] = backward_data.pop(data_name)
|
|
49
50
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
51
|
+
input_data = load_json(input_file)
|
|
52
|
+
if input_data.get("data") is None:
|
|
53
|
+
logger.error("Invalid input file, 'data' field is missing")
|
|
54
|
+
raise CompareException("Invalid input file, 'data' field is missing")
|
|
55
|
+
input_data.pop("data")
|
|
53
56
|
|
|
54
57
|
items = list(forward_data.items())
|
|
55
58
|
total_items = len(items)
|
|
@@ -69,8 +72,7 @@ def split_json_file(input_file, num_splits, filter_api):
|
|
|
69
72
|
}
|
|
70
73
|
}
|
|
71
74
|
split_filename = f"temp_part{i}.json"
|
|
72
|
-
|
|
73
|
-
json.dump(temp_data, split_file)
|
|
75
|
+
save_json(split_filename, temp_data)
|
|
74
76
|
split_files.append(split_filename)
|
|
75
77
|
|
|
76
78
|
return split_files, total_items
|
|
@@ -122,7 +124,7 @@ def run_parallel_ut(config):
|
|
|
122
124
|
if output == '':
|
|
123
125
|
break
|
|
124
126
|
if '[ERROR]' in output:
|
|
125
|
-
logger.warning(output
|
|
127
|
+
logger.warning(output)
|
|
126
128
|
sys.stdout.flush()
|
|
127
129
|
except ValueError as e:
|
|
128
130
|
logger.warning(f"An error occurred while reading subprocess output: {e}")
|
|
@@ -182,16 +184,19 @@ def run_parallel_ut(config):
|
|
|
182
184
|
|
|
183
185
|
|
|
184
186
|
def prepare_config(args):
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
out_path =
|
|
189
|
-
check_path_before_create(out_path)
|
|
187
|
+
api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
|
|
188
|
+
ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
|
|
189
|
+
api_info = api_info_file_checker.common_check()
|
|
190
|
+
out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
|
|
190
191
|
create_directory(out_path)
|
|
191
192
|
out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
|
|
192
193
|
out_path = out_path_checker.common_check()
|
|
193
194
|
split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
|
|
194
|
-
config_path =
|
|
195
|
+
config_path = args.config_path if args.config_path else None
|
|
196
|
+
if config_path:
|
|
197
|
+
config_path_checker = FileChecker(config_path, FileCheckConst.FILE,
|
|
198
|
+
FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
|
|
199
|
+
config_path = config_path_checker.common_check()
|
|
195
200
|
result_csv_path = args.result_csv_path or os.path.join(
|
|
196
201
|
out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
|
|
197
202
|
if not args.result_csv_path:
|
|
@@ -28,11 +28,12 @@ else:
|
|
|
28
28
|
import torch
|
|
29
29
|
from tqdm import tqdm
|
|
30
30
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import generate_device_params, get_api_info
|
|
31
|
-
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api
|
|
32
|
-
from msprobe.core.common.file_utils import check_link
|
|
31
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, is_unsupported_api
|
|
32
|
+
from msprobe.core.common.file_utils import check_link, FileChecker
|
|
33
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
|
|
34
|
+
from msprobe.core.common.const import FileCheckConst, Const
|
|
33
35
|
from msprobe.pytorch.common.log import logger
|
|
34
36
|
from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
|
|
35
|
-
from msprobe.core.common.const import Const
|
|
36
37
|
|
|
37
38
|
|
|
38
39
|
def check_tensor_overflow(x):
|
|
@@ -74,23 +75,25 @@ def run_overflow_check(forward_file):
|
|
|
74
75
|
logger.info("start UT test")
|
|
75
76
|
forward_content, _, real_data_path = parse_json_info_forward_backward(forward_file)
|
|
76
77
|
for api_full_name, api_info_dict in tqdm(forward_content.items()):
|
|
78
|
+
if is_unsupported_api(api_full_name, is_overflow_check=True):
|
|
79
|
+
continue
|
|
77
80
|
try:
|
|
78
81
|
run_torch_api(api_full_name, api_info_dict, real_data_path)
|
|
79
82
|
except Exception as err:
|
|
80
83
|
_, api_name, _ = api_full_name.split(Const.SEP)
|
|
81
84
|
if "not implemented for 'Half'" in str(err):
|
|
82
|
-
logger.warning(f"API {api_name} not support half tensor in CPU
|
|
83
|
-
|
|
85
|
+
logger.warning(f"API {api_name} not support half tensor in CPU. This API does not support overflow "
|
|
86
|
+
"check, so it will be skipped.")
|
|
84
87
|
elif "expected scalar type Long" in str(err):
|
|
85
88
|
logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
|
|
86
|
-
|
|
89
|
+
"'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
|
|
87
90
|
else:
|
|
88
91
|
logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
|
|
89
92
|
|
|
90
93
|
|
|
91
94
|
def run_torch_api(api_full_name, api_info_dict, real_data_path):
|
|
92
95
|
torch.npu.clear_npu_overflow_flag()
|
|
93
|
-
api_type, api_name
|
|
96
|
+
api_type, api_name = extract_basic_api_segments(api_full_name)
|
|
94
97
|
args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
|
|
95
98
|
if not need_grad:
|
|
96
99
|
logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
|
|
@@ -135,8 +138,9 @@ def _run_overflow_check(parser=None):
|
|
|
135
138
|
def _run_overflow_check_command(args):
|
|
136
139
|
torch.npu.set_compile_mode(jit_compile=args.jit_compile)
|
|
137
140
|
npu_device = "npu:" + str(args.device_id)
|
|
138
|
-
|
|
139
|
-
|
|
141
|
+
api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
|
|
142
|
+
ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
|
|
143
|
+
api_info = api_info_file_checker.common_check()
|
|
140
144
|
try:
|
|
141
145
|
torch.npu.set_device(npu_device)
|
|
142
146
|
except Exception as error:
|