mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -0,0 +1,365 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import math
|
|
4
|
+
from enum import Enum, auto
|
|
5
|
+
import torch
|
|
6
|
+
try:
|
|
7
|
+
import torch_npu
|
|
8
|
+
except ImportError:
|
|
9
|
+
pass
|
|
10
|
+
from tabulate import tabulate
|
|
11
|
+
|
|
12
|
+
# Data "type" strings recognized as tensors when regenerating inputs from dump info.
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
# Dtype-name groups compared against str(tensor.dtype) to pick a comparison path.
TORCH_BOOL_TYPE = ["torch.bool"]
TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int",
                  "torch.int64", "torch.long"]
TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float",
                    "torch.float64", "torch.double"]
TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128", "torch.cdouble"]
# Maps a dtype name to the higher-precision dtype used for the benchmark run.
# NOTE: this file is a .format() template, so literal braces are doubled ({{ }}).
RAISE_PRECISION = {{
    "torch.float16": torch.float32,
    "torch.half": torch.float32,
    "torch.bfloat16": torch.float32,
    "torch.float32": torch.float64,
    "torch.float": torch.float64
}}
# Relative-error threshold used by the "thousandth" comparison standard.
THOUSANDTH_THRESHOLDING = 0.001
BACKWARD = 'backward'

class CompareStandard(Enum):
    """Comparison standards selectable by the generated script.

    The concrete member is injected by the template's __main__ section via
    the {{compare_standard}} placeholder.
    """
    BINARY_EQUALITY_STANDARD = auto()
    ABSOLUTE_THRESHOLD_STANDARD = auto()
    ULP_ERROR_STANDARD = auto()
    BENCHMARK_STANDARD = auto()
    THOUSANDTH_STANDARD = auto()
|
|
35
|
+
|
|
36
|
+
def load_pt(pt_path, to_cpu=False):
    """Load a serialized torch object from *pt_path*.

    When *to_cpu* is True every storage is remapped onto the CPU device.
    Raises RuntimeError (chained to the underlying error) if loading fails.
    """
    pt_path = os.path.realpath(pt_path)
    try:
        loaded = (torch.load(pt_path, map_location=torch.device("cpu"))
                  if to_cpu else torch.load(pt_path))
    except Exception as e:
        raise RuntimeError(f"load pt file {{pt_path}} failed") from e
    return loaded
|
|
46
|
+
|
|
47
|
+
def get_device():
    """Return the available accelerator device, preferring CUDA over NPU.

    Raises Exception when neither a GPU nor an NPU is available.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch_npu is imported best-effort at module level (try/except ImportError);
    # guard against the name being absent so a CPU-only host raises the intended
    # Exception below instead of a NameError.
    if "torch_npu" in globals() and torch_npu.npu.is_available():
        return torch.device("npu")
    raise Exception("Error: This device is not NPU or GPU!")
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def generate_bool_tensor(low, high, shape):
    """Generate a random bool tensor of *shape*.

    Integers are drawn uniformly from [int(low), int(high)]; every draw
    strictly greater than zero becomes True.
    """
    lo = int(low)
    hi = int(high)
    draws = torch.randint(lo, hi + 1, shape)
    return draws > 0
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def generate_numerical_tensor(low, high, shape, data_dtype):
    """Generate a random tensor in [low, high] for the given dtype name.

    Float dtype names use a uniform draw scaled into the range; int dtype
    names use randint.  For non-empty tensors the first and last flattened
    elements are pinned to low/high so both extremes are always present.
    Raises NotImplementedError for unsupported dtype names.
    """
    if data_dtype in TORCH_FLOAT_TYPE:
        span = high - low
        tensor = torch.rand(shape, dtype=eval(data_dtype)) * span + low
    elif data_dtype in TORCH_INT_TYPE:
        low = int(low)
        high = int(high)
        tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype))
    else:
        raise NotImplementedError(f"{{data_dtype}} is not supported!")
    if torch.numel(tensor) == 0:
        return tensor
    flat = tensor.reshape(-1)
    flat[0] = low
    flat[-1] = high
    return flat.reshape(shape)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def generate_random_tensor(info):
    """Build a random tensor from a dump-info dict (keys Min/Max/dtype/shape)."""
    lo = info.get('Min')
    hi = info.get('Max')
    dtype_name = info.get('dtype')
    shape = tuple(info.get('shape'))
    if dtype_name == "torch.bool":
        return generate_bool_tensor(lo, hi, shape)
    return generate_numerical_tensor(lo, hi, shape, dtype_name)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def generate_real_tensor(data_path):
    """Load a dumped tensor from disk, forcing all storages onto the CPU."""
    real_path = os.path.realpath(data_path)
    return load_pt(real_path, to_cpu=True)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def generate_data(info):
    """Materialize one call argument described by a dump-info dict.

    Tensor entries are reloaded from disk when "data_name" is set, otherwise
    randomly regenerated from the recorded statistics; non-tensor entries fall
    back to the stored "value".  When "requires_grad" equals True, gradient
    tracking is enabled in place before returning.
    """
    kind = info.get("type")
    if kind in TENSOR_DATA_LIST:
        path = info.get("data_name")
        data = generate_real_tensor(path) if path else generate_random_tensor(info)
    else:
        data = info.get("value")
    # NOTE(review): a non-tensor value with requires_grad set would fail on
    # requires_grad_; presumably the dumper only records it for tensors — confirm.
    if info.get("requires_grad") == True:
        data.requires_grad_(True)
    return data
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def get_input(propagation):
    """Rebuild the device and benchmark inputs of the recorded API call.

    The single-brace placeholders below are substituted by op_generator.py
    when this template is rendered into a runnable script.
    """
    # Positional arguments, materialized once per run (device and benchmark).
    {args_element_assignment}
    args_device = [{args_list_generator_device}]
    args_bench = [{args_list_generator_bench}]
    # Keyword arguments for both runs.
    {kwargs_value_assignment}
    kwargs_device = {{{kwargs_dict_generator_device}}}
    kwargs_bench = {{{kwargs_dict_generator_bench}}}
    # Gradient inputs, only consumed for backward propagation.
    {args_element_assignment_backward}
    args_device_backward = [{args_list_generator_device_backward}]
    args_bench_backward = [{args_list_generator_bench_backward}]
    if propagation == BACKWARD:
        return args_device, kwargs_device, args_bench, kwargs_bench, args_device_backward, args_bench_backward
    return args_device, kwargs_device, args_bench, kwargs_bench
|
|
129
|
+
|
|
130
|
+
def exec_api(args, kwargs, args_grad_input, propagation):
    """Execute the templated API once; for backward, return its input gradients.

    args/kwargs: already-materialized call arguments.
    args_grad_input: grad_outputs fed to torch.autograd.grad in the backward case.
    Returns the forward output, or a tuple of gradients when propagation is BACKWARD.
    """
    output = {api_type}.{api_name}(*args, **kwargs)
    if propagation == BACKWARD:
        # Collect every grad-tracking tensor argument: positional first, then
        # keyword values, matching the order the inputs were dumped in.
        args_input_tensor = [tensor for tensor in args if isinstance(tensor, torch.Tensor) and tensor.requires_grad]
        args_input_tensor.extend(
            [value for value in kwargs.values() if isinstance(value, torch.Tensor) and value.requires_grad])
        output_backward = torch.autograd.grad(outputs=output, inputs=args_input_tensor, grad_outputs=args_grad_input)
        return output_backward
    return output
|
|
139
|
+
|
|
140
|
+
def compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol):
    """Fraction of inf/nan positions whose clipped relative error exceeds rtol.

    Both tensors are clamped into the finite range of the device dtype so that
    matching infinities compare equal; positions where both sides are NaN also
    count as passing.  Returns 0 when inf_nan_mask selects nothing.
    """
    finfo = torch.finfo(out_device.dtype)
    bench = out_bench.to(out_device.dtype)
    bench_clip = torch.clamp(bench, min=finfo.min, max=finfo.max)
    device_clip = torch.clamp(out_device, min=finfo.min, max=finfo.max)
    clipped_rel_err = torch.abs(device_clip - bench_clip) / abs_bench_with_eps
    passed = torch.less_equal(clipped_rel_err, rtol)
    both_nan = torch.logical_and(torch.isnan(out_device), torch.isnan(bench_clip))
    passed = torch.logical_or(passed, both_nan)
    failed = torch.logical_and(torch.logical_not(passed), inf_nan_mask)
    total = torch.sum(inf_nan_mask)
    return 0 if total == 0 else torch.sum(failed) / total
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def compute_rmse(abs_err, normal_value_mask):
    """Root-mean-square of abs_err restricted to the masked positions.

    Returns 0 when the mask selects no elements.
    """
    selected = torch.sum(normal_value_mask)
    if selected == 0:
        return 0
    masked = torch.where(normal_value_mask, abs_err, 0)
    return torch.sqrt(torch.sum(torch.square(masked)) / selected)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def compute_error_balance(out_device, out_bench):
    """Sign bias of the error: |#(device > bench) - #(device < bench)| / numel.

    Raises ZeroDivisionError for an empty benchmark tensor.
    """
    if torch.numel(out_bench) == 0:
        raise ZeroDivisionError(f"ERROR: please check torch.numel out_bench, its value is {{torch.numel(out_bench)}}")
    diff = out_device - out_bench.to(out_device.dtype)
    larger = torch.sum(torch.greater(diff, 0))
    smaller = torch.sum(torch.less(diff, 0))
    return abs(larger - smaller) / torch.numel(out_bench)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def compare_tensor(out_device, out_bench, api_name):
    """Compare a device tensor against its benchmark and print a metric table.

    The metric set depends on the module-global ``compare_standard`` (injected
    by the template's __main__ section): binary equality, absolute threshold,
    ULP error, thousandth ratio, or the default benchmark standard.
    Always returns None; results are only printed (via tabulate).
    """
    # Guard clauses: nothing to compare on shape mismatch or empty tensors.
    if out_device.shape != out_bench.shape:
        print("ERROR: shape of out_device and out_bench is not equal!")
        return None
    if torch.numel(out_bench) == 0:
        print("Both out_device and out_bench have zero elements.")
        return None
    dtype_device = out_device.dtype
    dtype_bench = out_bench.dtype
    headers = ["Metric", "Value"]
    table = [
        ["Shape", out_bench.shape],
        ["Dtype of out_device", out_device.dtype],
        ["Dtype of out_bench", out_bench.dtype]
    ]
    # Only comparable when both dtypes fall in the same broad category
    # (float/float, int/int, or bool/bool).
    if str(dtype_device) in TORCH_FLOAT_TYPE and str(dtype_bench) in TORCH_FLOAT_TYPE \
            or str(dtype_device) in TORCH_INT_TYPE and str(dtype_bench) in TORCH_INT_TYPE \
            or str(dtype_device) in TORCH_BOOL_TYPE and str(dtype_bench) in TORCH_BOOL_TYPE:
        out_device = out_device.to(torch.device("cpu"))
        # Exact element-wise comparison for bool/int outputs, or when binary
        # equality was explicitly requested.
        if str(dtype_device) in TORCH_BOOL_TYPE or str(dtype_device) in TORCH_INT_TYPE or compare_standard == CompareStandard.BINARY_EQUALITY_STANDARD:
            error_number = torch.sum(out_device != out_bench).item()
            if torch.numel(out_bench) == 0:
                raise ZeroDivisionError(f"ERROR: please check torch.numel out_bench, its value is {{torch.numel(out_bench)}}")
            error_rate = error_number / torch.numel(out_bench)
            table.append(["Compare Standard", "Binary Equality Standard"])
            table.append(["Error Rate", error_rate])
        else:
            # Float path: absolute/relative error and finite masks shared by
            # every float comparison standard below.
            abs_err = torch.abs(out_device - out_bench)
            abs_bench = torch.abs(out_bench)
            # NOTE(review): eps is only assigned for float32/float64 benchmarks;
            # a half-precision out_bench would hit a NameError on the next use —
            # presumably the generator always upcasts the benchmark run; confirm.
            if dtype_bench == torch.float32:
                eps = 2 ** -23
            if dtype_bench == torch.float64:
                eps = 2 ** -52
            abs_bench_with_eps = abs_bench + eps
            rel_err = torch.abs(abs_err / abs_bench_with_eps)
            device_finite_mask = torch.isfinite(out_device)
            bench_finite_mask = torch.isfinite(out_bench.to(dtype_device))
            both_finite_mask = torch.logical_and(device_finite_mask, bench_finite_mask)
            inf_nan_mask = torch.logical_not(both_finite_mask)
            if compare_standard == CompareStandard.ABSOLUTE_THRESHOLD_STANDARD:
                # Tolerances scale with the device precision.
                if dtype_device == torch.float16:
                    rtol, small_value, small_value_atol = 1.0e-3, 1.0e-3, 1.0e-5
                elif dtype_device == torch.bfloat16:
                    rtol, small_value, small_value_atol = 4.0e-3, 1.0e-3, 1.0e-5
                else:
                    rtol, small_value, small_value_atol = 1.0e-6, 1.0e-6, 1.0e-9
                # "Small" benchmark values use an absolute-error criterion,
                # the remaining (normal) values a relative one.
                small_value_mask = torch.less_equal(abs_bench, small_value)
                small_value_mask = torch.logical_and(small_value_mask, both_finite_mask)
                normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask))
                # NOTE(review): inf_nan_proportion is computed but never reported.
                inf_nan_proportion = compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol)
                rel_err_mask = torch.greater(rel_err, rtol)
                rel_err_mask = torch.logical_and(rel_err_mask, normal_value_mask)
                if torch.sum(normal_value_mask) == 0:
                    rel_err_proportion = 0
                else:
                    rel_err_proportion = torch.sum(rel_err_mask) / torch.sum(normal_value_mask)
                abs_err_mask = torch.greater(abs_err, small_value_atol)
                abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask)
                if torch.sum(small_value_mask) == 0:
                    abs_err_proportion = 0
                else:
                    abs_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask)
                table.append(["Compare Standard", "Absolute Threshold Standard"])
                table.append(["Relative Error Ratio", rel_err_proportion])
                table.append(["Absolute Error Ratio", abs_err_proportion])
            elif compare_standard == CompareStandard.ULP_ERROR_STANDARD:
                # Minimum exponent and mantissa bit count per device dtype.
                if dtype_device == torch.float16:
                    min_eb, exponent_num = -14, 10
                elif dtype_device == torch.bfloat16:
                    min_eb, exponent_num = -126, 7
                else:
                    min_eb, exponent_num = -126, 23
                # Binade of the benchmark value, floored at the subnormal limit.
                eb = torch.where(abs_bench == 0, torch.zeros(out_bench.shape), torch.floor(torch.log2(abs_bench)))
                eb = torch.maximum(eb, min_eb * torch.ones(out_bench.shape))
                # Scale the difference into units-in-the-last-place, computed
                # one precision level above the device dtype.
                if dtype_device == torch.float32:
                    ulp_err = (out_device.to(torch.float64) - out_bench).to(torch.float64) * torch.exp2(-eb + exponent_num).to(torch.float64)
                else:
                    ulp_err = (out_device.to(torch.float32) - out_bench).to(torch.float32) * torch.exp2(-eb + exponent_num).to(torch.float32)
                ulp_err = torch.abs(ulp_err)
                max_ulp_err = torch.max(ulp_err)
                mean_ulp_err = torch.mean(ulp_err)
                if torch.numel(out_bench) == 0:
                    raise ZeroDivisionError(f"ERROR: please check torch.numel out_bench, its value is {{torch.numel(out_bench)}}")
                # float32 tolerates up to 32 ULP; half precisions only 1.
                if dtype_device == torch.float32:
                    ulp_err_proportion = torch.sum(ulp_err > 32) / torch.numel(out_bench)
                else:
                    ulp_err_proportion = torch.sum(ulp_err > 1) / torch.numel(out_bench)
                table.append(["Compare Standard", "ULP error Standard"])
                table.append(["Maximum ULP Error", max_ulp_err])
                table.append(["Mean ULP Error", mean_ulp_err])
                table.append(["ULP Error Proportion", ulp_err_proportion])
            elif compare_standard == CompareStandard.THOUSANDTH_STANDARD:
                rel_err_origin = torch.abs(abs_err / abs_bench_with_eps)
                if torch.numel(rel_err_origin) == 0:
                    thousand_res = 1
                else:
                    thousand_res = torch.divide(torch.sum(rel_err < THOUSANDTH_THRESHOLDING), torch.numel(rel_err_origin))
                # NOTE(review): thousand_status is computed but never used or reported.
                thousand_status = thousand_res > (1 - THOUSANDTH_THRESHOLDING)
                table.append(["Compare Standard", "Thousandth Standard"])
                table.append(["Thousandth ratio", thousand_res])
            else:
                # Default: benchmark standard (small-value / relative-error statistics).
                if dtype_device == torch.float16:
                    small_value, small_value_atol = 1.0e-3, 1.0e-5
                elif dtype_device == torch.bfloat16:
                    small_value, small_value_atol = 1.0e-3, 1.0e-5
                else:
                    small_value, small_value_atol = 1.0e-6, 1.0e-9
                small_value_mask = torch.less_equal(abs_bench, small_value)
                small_value_mask = torch.logical_and(small_value_mask, both_finite_mask)
                normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask))
                abs_err_mask = torch.greater(abs_err, small_value_atol)
                abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask)
                if torch.sum(small_value_mask) == 0:
                    small_value_err_proportion = 0
                else:
                    small_value_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask)
                # Mark non-normal positions with -1 so they never win the max below.
                rel_err = torch.where(normal_value_mask, rel_err, -1 * torch.ones(out_device.shape))
                if torch.max(rel_err) >= 0:
                    max_rel_err = torch.max(rel_err)
                else:
                    max_rel_err = 0
                if torch.sum(normal_value_mask) == 0:
                    mean_rel_err = 0
                else:
                    mean_rel_err = torch.sum(torch.clamp(rel_err, min=0)) / torch.sum(normal_value_mask)
                rmse = compute_rmse(abs_err, normal_value_mask)
                error_balance = compute_error_balance(out_device, out_bench)
                table.append(["Compare Standard", "Benchmark Standard"])
                table.append(["Small Value Error Proportion", small_value_err_proportion])
                table.append(["Maximum Relative Error", max_rel_err])
                table.append(["Mean Relative Error", mean_rel_err])
                table.append(["Root Mean Squared Error", rmse])
                table.append(["Error Balance", error_balance])
    else:
        print(f"ERROR: out_device dtype is {{dtype_device}}, out_bench dtype is {{dtype_bench}}, not comparable.")
        return None
    print(tabulate(table, headers, tablefmt='grid'))
    return None
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def compare_element(out_device, out_bench, api_name):
|
|
317
|
+
if type(out_device) != type(out_bench):
|
|
318
|
+
print("ERROR: out_device and out_bench is not the same type!")
|
|
319
|
+
return None
|
|
320
|
+
if isinstance(out_bench, torch.Tensor):
|
|
321
|
+
compare_tensor(out_device, out_bench, api_name)
|
|
322
|
+
elif isinstance(out_bench, (bool, int, float, str)):
|
|
323
|
+
if out_device == out_bench:
|
|
324
|
+
print("PASS: out_device and out_bench equals.")
|
|
325
|
+
else:
|
|
326
|
+
print("ERROR: out_device and out_bench is not equal!")
|
|
327
|
+
else:
|
|
328
|
+
print(f"ERROR: comparison of type {{type(out_bench)}} is not supported.")
|
|
329
|
+
return None
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def compare(out_device, out_bench, api_name):
|
|
333
|
+
print("Compare result:")
|
|
334
|
+
if type(out_device) != type(out_bench):
|
|
335
|
+
print("ERROR: out_device and out_bench is not the same type!")
|
|
336
|
+
return None
|
|
337
|
+
if isinstance(out_bench, (list, tuple)):
|
|
338
|
+
if len(out_device) != len(out_bench):
|
|
339
|
+
print("ERROR: len of out_device and out_bench is different!")
|
|
340
|
+
return None
|
|
341
|
+
for index, _ in enumerate(out_bench):
|
|
342
|
+
print(f"index {{index}}:")
|
|
343
|
+
compare_element(out_device[index], out_bench[index], api_name)
|
|
344
|
+
else:
|
|
345
|
+
compare_element(out_device, out_bench, api_name)
|
|
346
|
+
|
|
347
|
+
if __name__ == "__main__":
|
|
348
|
+
device = get_device()
|
|
349
|
+
api_name = "{api_name}"
|
|
350
|
+
propagation = "{propagation}"
|
|
351
|
+
compare_standard = {compare_standard}
|
|
352
|
+
torch.manual_seed({random_seed})
|
|
353
|
+
for i in range({iter_times}):
|
|
354
|
+
print(f"iter: {{i}}:")
|
|
355
|
+
if propagation == BACKWARD:
|
|
356
|
+
args_device, kwargs_device, args_bench, kwargs_bench, args_device_backward, args_bench_backward = get_input(propagation)
|
|
357
|
+
output_device = exec_api(args_device, kwargs_device, args_device_backward, propagation)
|
|
358
|
+
output_bench = exec_api(args_bench, kwargs_bench, args_bench_backward, propagation)
|
|
359
|
+
compare(output_device, output_bench, api_name)
|
|
360
|
+
else:
|
|
361
|
+
args_device, kwargs_device, args_bench, kwargs_bench = get_input(propagation)
|
|
362
|
+
output_device = exec_api(args_device, kwargs_device, None, propagation)
|
|
363
|
+
output_bench = exec_api(args_bench, kwargs_bench, None, propagation)
|
|
364
|
+
compare(output_device, output_bench, api_name)
|
|
365
|
+
print("Compare finished.")
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
2
|
# -*- coding: utf-8 -*-
|
|
3
|
-
|
|
4
|
-
#
|
|
5
|
-
#
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
7
|
# you may not use this file except in compliance with the License.
|
|
7
8
|
# You may obtain a copy of the License at
|
|
8
9
|
#
|
|
@@ -13,7 +14,6 @@
|
|
|
13
14
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
15
|
# See the License for the specific language governing permissions and
|
|
15
16
|
# limitations under the License.
|
|
16
|
-
"""
|
|
17
17
|
|
|
18
18
|
import os
|
|
19
19
|
import math
|
|
@@ -22,19 +22,28 @@ import numpy
|
|
|
22
22
|
|
|
23
23
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
|
|
24
24
|
from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \
|
|
25
|
-
CompareException
|
|
25
|
+
CompareException, get_module_and_atttribute_name, get_attribute
|
|
26
26
|
from msprobe.core.common.file_utils import FileChecker, load_npy
|
|
27
27
|
from msprobe.pytorch.common.log import logger
|
|
28
28
|
from msprobe.pytorch.common.utils import load_pt
|
|
29
|
-
from msprobe.core.common.const import Const, FileCheckConst
|
|
29
|
+
from msprobe.core.common.const import Const, FileCheckConst, CompareConst
|
|
30
30
|
|
|
31
31
|
TORCH_TYPE = ["torch.device", "torch.dtype"]
|
|
32
32
|
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
|
|
33
|
-
FLOAT_TYPE = [
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
33
|
+
FLOAT_TYPE = [
|
|
34
|
+
'torch.float32',
|
|
35
|
+
'torch.float',
|
|
36
|
+
'torch.float64',
|
|
37
|
+
'torch.double',
|
|
38
|
+
'torch.float16',
|
|
39
|
+
'torch.half',
|
|
40
|
+
'torch.bfloat16'
|
|
41
|
+
]
|
|
42
|
+
NUMPY_TYPE = [
|
|
43
|
+
"numpy.int8", "numpy.int16", "numpy.int32", "numpy.int64", "numpy.uint8", "numpy.uint16", "numpy.uint32",
|
|
44
|
+
"numpy.uint64", "numpy.float16", "numpy.float32", "numpy.float64", "numpy.float128", "numpy.complex64",
|
|
45
|
+
"numpy.complex128", "numpy.complex256", "numpy.bool_", "numpy.string_", "numpy.bytes_", "numpy.unicode_"
|
|
46
|
+
]
|
|
38
47
|
|
|
39
48
|
|
|
40
49
|
def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
|
|
@@ -68,7 +77,8 @@ def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
|
|
|
68
77
|
raise Exception("{} is not supported now".format(data_type))
|
|
69
78
|
data = info.get("value")
|
|
70
79
|
try:
|
|
71
|
-
|
|
80
|
+
module_name, attribute_name = get_module_and_atttribute_name(data_type)
|
|
81
|
+
data = get_attribute(module_name, attribute_name)(data)
|
|
72
82
|
except Exception as err:
|
|
73
83
|
logger.error("Failed to convert the type to numpy: %s" % str(err))
|
|
74
84
|
elif data_type == "torch.Size":
|
|
@@ -104,8 +114,9 @@ def gen_real_tensor(data_path, convert_type):
|
|
|
104
114
|
if convert_type:
|
|
105
115
|
ori_dtype = Const.CONVERT.get(convert_type)[0]
|
|
106
116
|
dist_dtype = Const.CONVERT.get(convert_type)[1]
|
|
117
|
+
module_name, attribute_name = get_module_and_atttribute_name(dist_dtype)
|
|
107
118
|
if str(data.dtype) == ori_dtype:
|
|
108
|
-
data = data.type(
|
|
119
|
+
data = data.type(get_attribute(module_name, attribute_name))
|
|
109
120
|
return data
|
|
110
121
|
|
|
111
122
|
|
|
@@ -118,13 +129,22 @@ def gen_random_tensor(info, convert_type):
|
|
|
118
129
|
convert_type: convert ori_type to dist_type flag.
|
|
119
130
|
"""
|
|
120
131
|
check_object_type(info, dict)
|
|
121
|
-
|
|
122
|
-
low_origin
|
|
132
|
+
|
|
133
|
+
low_origin = info.get('Min')
|
|
134
|
+
low = info.get('Min_except_inf_nan', low_origin)
|
|
135
|
+
high_origin = info.get('Max')
|
|
136
|
+
high = info.get('Max_except_inf_nan', high_origin)
|
|
137
|
+
|
|
123
138
|
low_info = [low, low_origin]
|
|
124
139
|
high_info = [high, high_origin]
|
|
125
140
|
data_dtype = info.get('dtype')
|
|
126
141
|
shape = tuple(info.get('shape'))
|
|
127
|
-
if
|
|
142
|
+
if 0 in shape:
|
|
143
|
+
low, low_origin = 0, 0
|
|
144
|
+
high, high_origin = 0, 0
|
|
145
|
+
low_info = [low, low_origin]
|
|
146
|
+
high_info = [high, high_origin]
|
|
147
|
+
elif not isinstance(low, (int, float)) or not isinstance(high, (int, float)):
|
|
128
148
|
error_info = f'Data info Min: {low} , Max: {high}, info type must be int or float.'
|
|
129
149
|
raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
|
|
130
150
|
if data_dtype == "torch.bool":
|
|
@@ -164,33 +184,35 @@ def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type):
|
|
|
164
184
|
data_dtype = Const.CONVERT.get(convert_type)[1]
|
|
165
185
|
low, low_origin = low_info[0], low_info[1]
|
|
166
186
|
high, high_origin = high_info[0], high_info[1]
|
|
167
|
-
|
|
187
|
+
module_name, attribute_name = get_module_and_atttribute_name(data_dtype)
|
|
188
|
+
dtype = get_attribute(module_name, attribute_name)
|
|
189
|
+
if data_dtype in FLOAT_TYPE:
|
|
168
190
|
if math.isnan(high):
|
|
169
|
-
tensor = torch.
|
|
191
|
+
tensor = torch.full(shape, float('nan'), dtype=dtype)
|
|
170
192
|
return tensor
|
|
171
193
|
#high_origin为新版json中的属性,只有当high_origin不为None,且high为inf或-inf时,原tensor全为inf或-inf
|
|
172
|
-
if high_origin and high in [float(
|
|
173
|
-
tensor = torch.
|
|
194
|
+
if high_origin and high in [float(CompareConst.INF), float(CompareConst.NEG_INF)]:
|
|
195
|
+
tensor = torch.full(shape, high, dtype=dtype)
|
|
174
196
|
tensor[-1] = low
|
|
175
197
|
return tensor
|
|
176
198
|
low_scale, high_scale = low, high
|
|
177
|
-
dtype_finfo = torch.finfo(
|
|
199
|
+
dtype_finfo = torch.finfo(dtype)
|
|
178
200
|
#适配老版json high和low为inf或-inf的情况,取dtype的最大值或最小值进行放缩
|
|
179
|
-
if high == float(
|
|
201
|
+
if high == float(CompareConst.INF):
|
|
180
202
|
high_scale = dtype_finfo.max
|
|
181
|
-
elif high == float(
|
|
203
|
+
elif high == float(CompareConst.NEG_INF):
|
|
182
204
|
high_scale = dtype_finfo.min
|
|
183
|
-
if low == float(
|
|
205
|
+
if low == float(CompareConst.INF):
|
|
184
206
|
low_scale = dtype_finfo.max
|
|
185
|
-
elif low == float(
|
|
207
|
+
elif low == float(CompareConst.NEG_INF):
|
|
186
208
|
low_scale = dtype_finfo.min
|
|
187
209
|
|
|
188
210
|
scale = high_scale - low_scale
|
|
189
|
-
rand01 = torch.rand(shape, dtype=
|
|
211
|
+
rand01 = torch.rand(shape, dtype=dtype)
|
|
190
212
|
tensor = rand01 * scale + low_scale
|
|
191
213
|
elif 'int' in data_dtype or 'long' in data_dtype:
|
|
192
214
|
low, high = int(low), int(high)
|
|
193
|
-
tensor = torch.randint(low, high + 1, shape, dtype=
|
|
215
|
+
tensor = torch.randint(low, high + 1, shape, dtype=dtype)
|
|
194
216
|
else:
|
|
195
217
|
logger.error('Dtype is not supported: ' + data_dtype)
|
|
196
218
|
raise NotImplementedError()
|
|
@@ -208,9 +230,9 @@ def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type):
|
|
|
208
230
|
else:
|
|
209
231
|
tmp_tensor[0] = low
|
|
210
232
|
tmp_tensor[-1] = high
|
|
211
|
-
if high_origin in [float(
|
|
233
|
+
if high_origin in [float(CompareConst.INF), float(CompareConst.NEG_INF)]:
|
|
212
234
|
tmp_tensor[-1] = high_origin
|
|
213
|
-
if low_origin in [float(
|
|
235
|
+
if low_origin in [float(CompareConst.INF), float(CompareConst.NEG_INF)]:
|
|
214
236
|
tmp_tensor[0] = low_origin
|
|
215
237
|
data = tmp_tensor.reshape(shape)
|
|
216
238
|
return data
|
|
@@ -233,7 +255,7 @@ def gen_bool_tensor(low, high, shape):
|
|
|
233
255
|
return data
|
|
234
256
|
|
|
235
257
|
|
|
236
|
-
def gen_args(args_info, api_name,
|
|
258
|
+
def gen_args(args_info, api_name, func_options):
|
|
237
259
|
"""
|
|
238
260
|
Function Description:
|
|
239
261
|
Based on API basic information, generate input parameters: args, for API forward running
|
|
@@ -246,9 +268,20 @@ def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_p
|
|
|
246
268
|
"""
|
|
247
269
|
check_object_type(args_info, list)
|
|
248
270
|
args_result = []
|
|
271
|
+
|
|
272
|
+
need_grad = func_options.get('need_grad', True)
|
|
273
|
+
convert_type = func_options.get('convert_type', None)
|
|
274
|
+
real_data_path = func_options.get('real_data_path', None)
|
|
275
|
+
depth = func_options.get('depth', 0)
|
|
276
|
+
|
|
277
|
+
if depth > Const.MAX_DEPTH:
|
|
278
|
+
logger.error("The depth of args is too large, please check the input args.")
|
|
279
|
+
raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
|
|
280
|
+
|
|
249
281
|
for arg in args_info:
|
|
250
282
|
if isinstance(arg, (list, tuple)):
|
|
251
|
-
|
|
283
|
+
func_options['depth'] = depth + 1
|
|
284
|
+
data = gen_args(arg, api_name, func_options)
|
|
252
285
|
elif isinstance(arg, dict):
|
|
253
286
|
data = gen_data(arg, api_name, need_grad, convert_type, real_data_path)
|
|
254
287
|
elif arg is None:
|
|
@@ -288,7 +321,8 @@ def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):
|
|
|
288
321
|
|
|
289
322
|
def gen_torch_kwargs(kwargs_params, key, value):
|
|
290
323
|
if value.get('type') != "torch.device":
|
|
291
|
-
|
|
324
|
+
module_name, attribute_name = get_module_and_atttribute_name(value.get('value'))
|
|
325
|
+
kwargs_params[key] = get_attribute(module_name, attribute_name)
|
|
292
326
|
|
|
293
327
|
|
|
294
328
|
def gen_list_kwargs(kwargs_item_value, api_name, convert_type, real_data_path=None):
|
|
@@ -327,8 +361,14 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d
|
|
|
327
361
|
error_info = f"convert_type params not support {convert_type}."
|
|
328
362
|
raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
|
|
329
363
|
kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path)
|
|
364
|
+
func_options = {
|
|
365
|
+
'need_grad': need_grad,
|
|
366
|
+
'convert_type': convert_type,
|
|
367
|
+
'real_data_path': real_data_path,
|
|
368
|
+
'depth': 0
|
|
369
|
+
}
|
|
330
370
|
if api_info.get("input_args"):
|
|
331
|
-
args_params = gen_args(api_info.get("input_args"), api_name,
|
|
371
|
+
args_params = gen_args(api_info.get("input_args"), api_name, func_options)
|
|
332
372
|
else:
|
|
333
373
|
logger.warning(f'Warning: No args in {api_info} ')
|
|
334
374
|
args_params = []
|