mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -1,15 +1,27 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
2
15
|
|
|
3
16
|
import mindspore
|
|
4
17
|
import torch
|
|
5
18
|
from mindspore import ops
|
|
6
|
-
|
|
7
|
-
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
8
19
|
from msprobe.core.common.const import Const, MsCompareConst
|
|
9
20
|
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
10
|
-
from msprobe.mindspore.
|
|
11
|
-
from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
|
|
21
|
+
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
12
22
|
from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str
|
|
23
|
+
from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
|
|
24
|
+
from msprobe.mindspore.common.log import logger
|
|
13
25
|
|
|
14
26
|
|
|
15
27
|
class ApiInputAggregation:
|
|
@@ -24,11 +36,23 @@ class ApiInputAggregation:
|
|
|
24
36
|
self.kwargs = kwargs
|
|
25
37
|
self.gradient_inputs = gradient_inputs
|
|
26
38
|
|
|
39
|
+
|
|
27
40
|
api_parent_module_mapping = {
|
|
28
41
|
(MsCompareConst.MINT, Const.MS_FRAMEWORK): mindspore.mint,
|
|
29
42
|
(MsCompareConst.MINT, Const.PT_FRAMEWORK): torch,
|
|
30
43
|
(MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
|
|
31
|
-
(MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional
|
|
44
|
+
(MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional,
|
|
45
|
+
(MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): mindspore.Tensor,
|
|
46
|
+
(MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
api_parent_module_str_mapping = {
|
|
50
|
+
(MsCompareConst.MINT, Const.MS_FRAMEWORK): "mindspore.mint",
|
|
51
|
+
(MsCompareConst.MINT, Const.PT_FRAMEWORK): "torch",
|
|
52
|
+
(MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): "mindspore.mint.nn.functional",
|
|
53
|
+
(MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): "torch.nn.functional",
|
|
54
|
+
(MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): "mindspore.Tensor",
|
|
55
|
+
(MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor"
|
|
32
56
|
}
|
|
33
57
|
|
|
34
58
|
|
|
@@ -60,7 +84,7 @@ class ApiRunner:
|
|
|
60
84
|
api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"
|
|
61
85
|
|
|
62
86
|
Return:
|
|
63
|
-
api_type_str: str, Union["MintFunctional", "Mint"]
|
|
87
|
+
api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
|
|
64
88
|
api_sub_name: str, e.g. "relu"
|
|
65
89
|
'''
|
|
66
90
|
api_name_list = api_name_str.split(Const.SEP)
|
|
@@ -68,8 +92,8 @@ class ApiRunner:
|
|
|
68
92
|
err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
|
|
69
93
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
|
|
70
94
|
api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
|
|
71
|
-
if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL]:
|
|
72
|
-
err_msg = f"ApiRunner.get_info_from_name failed: not mint
|
|
95
|
+
if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API]:
|
|
96
|
+
err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api"
|
|
73
97
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
|
|
74
98
|
|
|
75
99
|
return api_type_str, api_sub_name
|
|
@@ -78,7 +102,7 @@ class ApiRunner:
|
|
|
78
102
|
def get_api_instance(api_type_str, api_sub_name, api_platform):
|
|
79
103
|
'''
|
|
80
104
|
Args:
|
|
81
|
-
api_type_str: str, Union["MintFunctional", "Mint"]
|
|
105
|
+
api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
|
|
82
106
|
api_sub_name: str, e.g. "relu"
|
|
83
107
|
api_platform: str: Union["mindpore", "torch"]
|
|
84
108
|
|
|
@@ -92,9 +116,8 @@ class ApiRunner:
|
|
|
92
116
|
'''
|
|
93
117
|
|
|
94
118
|
api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
full_api_name = module_str + submodule_str + api_sub_name
|
|
119
|
+
api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
|
|
120
|
+
full_api_name = api_parent_module_str + Const.SEP + api_sub_name
|
|
98
121
|
if not hasattr(api_parent_module, api_sub_name):
|
|
99
122
|
err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
|
|
100
123
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))
|
|
@@ -115,7 +138,7 @@ class ApiRunner:
|
|
|
115
138
|
gradient_inputs = api_input_aggregation.gradient_inputs
|
|
116
139
|
|
|
117
140
|
if forward_or_backward == Const.FORWARD:
|
|
118
|
-
forward_result = api_instance(*inputs, **kwargs)
|
|
141
|
+
forward_result = api_instance(*inputs, **kwargs) # can be single tensor or tuple
|
|
119
142
|
forward_result_tuple = convert_to_tuple(forward_result)
|
|
120
143
|
res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple]
|
|
121
144
|
else:
|
|
@@ -127,18 +150,20 @@ class ApiRunner:
|
|
|
127
150
|
if api_platform == Const.MS_FRAMEWORK:
|
|
128
151
|
if len(gradient_inputs) == 1:
|
|
129
152
|
gradient_inputs = gradient_inputs[0]
|
|
153
|
+
|
|
130
154
|
def api_with_kwargs(*forward_inputs):
|
|
131
155
|
return api_instance(*forward_inputs, **kwargs)
|
|
156
|
+
|
|
132
157
|
grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs)
|
|
133
|
-
backward_result = grad_func(*inputs, gradient_inputs)
|
|
158
|
+
backward_result = grad_func(*inputs, gradient_inputs) # can be single tensor or tuple
|
|
134
159
|
backward_result_tuple = convert_to_tuple(backward_result)
|
|
135
160
|
res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple]
|
|
136
161
|
else:
|
|
137
|
-
#set requires_grad
|
|
162
|
+
# set requires_grad
|
|
138
163
|
requires_grad_index = []
|
|
139
164
|
for index, tensor in enumerate(inputs):
|
|
140
165
|
if isinstance(tensor, torch.Tensor) and \
|
|
141
|
-
|
|
166
|
+
torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list:
|
|
142
167
|
setattr(tensor, "requires_grad", True)
|
|
143
168
|
requires_grad_index.append(index)
|
|
144
169
|
forward_results = api_instance(*inputs, **kwargs)
|
|
@@ -153,4 +178,4 @@ class ApiRunner:
|
|
|
153
178
|
return res_compute_element_list
|
|
154
179
|
|
|
155
180
|
|
|
156
|
-
api_runner = ApiRunner()
|
|
181
|
+
api_runner = ApiRunner()
|
|
@@ -1,12 +1,27 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from abc import ABC, abstractmethod
|
|
2
17
|
|
|
3
18
|
import mindspore
|
|
4
|
-
import torch
|
|
5
19
|
import numpy as np
|
|
6
|
-
|
|
20
|
+
import torch
|
|
21
|
+
from msprobe.core.common.const import CompareConst, MsCompareConst
|
|
7
22
|
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
8
23
|
from msprobe.mindspore.common.log import logger
|
|
9
|
-
|
|
24
|
+
|
|
10
25
|
|
|
11
26
|
class CompareResult:
|
|
12
27
|
def __init__(self, compare_value, pass_status, err_msg):
|
|
@@ -28,7 +43,7 @@ class BaseCompareAlgorithm(ABC):
|
|
|
28
43
|
CompareConst.MAX_ABS_ERR: {
|
|
29
44
|
CompareConst.PASS: "",
|
|
30
45
|
CompareConst.ERROR: "max absolute difference is greater than " \
|
|
31
|
-
|
|
46
|
+
f"threshold: {CompareConst.MAX_ABS_ERR_THRESHOLD} ",
|
|
32
47
|
CompareConst.SKIP: "two inputs are not valid for computing max absolute difference, skip comparing ",
|
|
33
48
|
},
|
|
34
49
|
CompareConst.MAX_RELATIVE_ERR: {
|
|
@@ -68,7 +83,7 @@ class BaseCompareAlgorithm(ABC):
|
|
|
68
83
|
ndarray = tensor.to(torch.float64, copy=True).numpy()
|
|
69
84
|
else:
|
|
70
85
|
err_msg = "BaseCompareAlgorithm.convert_to_np_float64_ndarray failed: " \
|
|
71
|
-
|
|
86
|
+
"input is not mindspore.Tensor or torch.Tensor"
|
|
72
87
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
|
|
73
88
|
return ndarray
|
|
74
89
|
|
|
@@ -189,9 +204,8 @@ class MaxRelativeDiffCompareAlgorithm(BaseCompareAlgorithm):
|
|
|
189
204
|
return CompareConst.ERROR
|
|
190
205
|
|
|
191
206
|
|
|
192
|
-
|
|
193
207
|
compare_algorithms = {
|
|
194
208
|
CompareConst.COSINE: CosineSimilarityCompareAlgorithm(),
|
|
195
209
|
CompareConst.MAX_ABS_ERR: MaxAbsoluteDiffCompareAlgorithm(),
|
|
196
210
|
CompareConst.MAX_RELATIVE_ERR: MaxRelativeDiffCompareAlgorithm(),
|
|
197
|
-
}
|
|
211
|
+
}
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
# list of api that can be checked
|
|
17
|
+
|
|
18
|
+
tensor:
|
|
19
|
+
- add_
|
|
20
|
+
- add
|
|
21
|
+
- addmm_
|
|
22
|
+
- all
|
|
23
|
+
- allclose
|
|
24
|
+
- any
|
|
25
|
+
- bool
|
|
26
|
+
- byte
|
|
27
|
+
- ceil
|
|
28
|
+
- clamp
|
|
29
|
+
- contiguous
|
|
30
|
+
- copy_
|
|
31
|
+
- cos
|
|
32
|
+
- clone
|
|
33
|
+
- cumprod
|
|
34
|
+
- expand_as
|
|
35
|
+
- flatten
|
|
36
|
+
- float
|
|
37
|
+
- half
|
|
38
|
+
- int
|
|
39
|
+
- is_contiguous
|
|
40
|
+
- isnan
|
|
41
|
+
- item
|
|
42
|
+
- log
|
|
43
|
+
- log2
|
|
44
|
+
- long
|
|
45
|
+
- masked_fill
|
|
46
|
+
- max
|
|
47
|
+
- mean
|
|
48
|
+
- min
|
|
49
|
+
- numel
|
|
50
|
+
- numpy
|
|
51
|
+
- repeat
|
|
52
|
+
- repeat_interleave
|
|
53
|
+
- reshape
|
|
54
|
+
- round
|
|
55
|
+
- select
|
|
56
|
+
- sin
|
|
57
|
+
- size
|
|
58
|
+
- split
|
|
59
|
+
- sqrt
|
|
60
|
+
- square
|
|
61
|
+
- sub
|
|
62
|
+
- swapaxes
|
|
63
|
+
- to
|
|
64
|
+
- t
|
|
65
|
+
- tolist
|
|
66
|
+
- topk
|
|
67
|
+
- transpose
|
|
68
|
+
- trunc
|
|
69
|
+
- type
|
|
70
|
+
- unsqueeze
|
|
71
|
+
- view
|
|
72
|
+
- view_as
|
|
73
|
+
- fill_
|
|
74
|
+
- floor_
|
|
75
|
+
- clamp_
|
|
76
|
+
- type_as
|
|
77
|
+
- zero_
|
|
@@ -1,6 +1,68 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import argparse
|
|
17
|
+
import os
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path, create_directory
|
|
21
|
+
from msprobe.core.common.utils import Const, MsprobeBaseException
|
|
22
|
+
|
|
23
|
+
class UniqueDeviceAction(argparse.Action):
|
|
24
|
+
def __call__(self, parser, namespace, values, option_string=None):
|
|
25
|
+
unique_values = set(values)
|
|
26
|
+
if len(values) != len(unique_values):
|
|
27
|
+
parser.error("device id must be unique")
|
|
28
|
+
for device_id in values:
|
|
29
|
+
if not 0 <= device_id <= 4095:
|
|
30
|
+
parser.error(f"the argument 'device_id' must be in range [0, 4095], but got {device_id}")
|
|
31
|
+
setattr(namespace, self.dest, values)
|
|
32
|
+
|
|
33
|
+
|
|
1
34
|
def add_api_accuracy_checker_argument(parser):
|
|
2
35
|
parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", type=str, required=True,
|
|
3
36
|
help="<Required> The api param tool result file: generate from api param tool, "
|
|
4
37
|
"a json file.")
|
|
5
38
|
parser.add_argument("-o", "--out_path", dest="out_path", default="./", type=str, required=False,
|
|
6
|
-
help="<optional> The ut task result out path.")
|
|
39
|
+
help="<optional> The ut task result out path.")
|
|
40
|
+
parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, required=False,
|
|
41
|
+
help="<optional> the exit csv for continue")
|
|
42
|
+
|
|
43
|
+
def multi_add_api_accuracy_checker_argument(parser):
|
|
44
|
+
parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", type=str, required=True,
|
|
45
|
+
help="<Required> The api param tool result file: generate from api param tool, "
|
|
46
|
+
"a json file.")
|
|
47
|
+
parser.add_argument("-o", "--out_path", dest="out_path", default="./", type=str, required=False,
|
|
48
|
+
help="<optional> The ut task result out path.")
|
|
49
|
+
parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, required=False,
|
|
50
|
+
help="<optional> the exit csv for continue")
|
|
51
|
+
#以下属于多线程参数
|
|
52
|
+
parser.add_argument("-d", "--device", dest="device_id", nargs='+', type=int,
|
|
53
|
+
help="<optional> set device id to run ut, must be unique and in range 0-7",
|
|
54
|
+
default=[0], required=False, action=UniqueDeviceAction)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def check_args(args):
|
|
58
|
+
args.api_info_file = os.path.abspath(args.api_info_file)
|
|
59
|
+
check_file_or_directory_path(args.api_info_file)
|
|
60
|
+
|
|
61
|
+
if args.out_path == "":
|
|
62
|
+
args.out_path = "./"
|
|
63
|
+
args.out_path = os.path.abspath(args.out_path)
|
|
64
|
+
create_directory(args.out_path)
|
|
65
|
+
|
|
66
|
+
if args.result_csv_path:
|
|
67
|
+
args.result_csv_path = os.path.abspath(args.result_csv_path)
|
|
68
|
+
check_file_or_directory_path(args.result_csv_path)
|
|
@@ -1,21 +1,37 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
|
|
3
18
|
import mindspore
|
|
4
|
-
import torch
|
|
5
19
|
import numpy as np
|
|
6
|
-
|
|
7
|
-
from
|
|
20
|
+
import torch
|
|
21
|
+
from mindspore._c_expression import typing
|
|
22
|
+
from msprobe.core.common.const import Const
|
|
8
23
|
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
9
24
|
from msprobe.core.common.file_utils import load_npy
|
|
10
|
-
from msprobe.mindspore.api_accuracy_checker.type_mapping import (
|
|
25
|
+
from msprobe.mindspore.api_accuracy_checker.type_mapping import (api_info_type_str_to_type,
|
|
11
26
|
ms_dtype_to_dtype_str, torch_dtype_to_dtype_str,
|
|
12
27
|
dtype_str_to_ms_dtype, dtype_str_to_np_dtype,
|
|
13
28
|
dtype_str_to_torch_dtype, type_to_api_info_type_str,
|
|
14
29
|
DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE, TUPLE_TYPE_STR,
|
|
15
|
-
MINDSPORE_TENSOR_TYPE_STR,
|
|
16
|
-
|
|
17
|
-
|
|
30
|
+
MINDSPORE_TENSOR_TYPE_STR, MINDSPORE_DTYPE_TYPE_STR,
|
|
31
|
+
SLICE_TYPE_STR, TORCH_DTYPE_TYPE_STR,
|
|
32
|
+
float_dtype_str_list, int_dtype_str_list)
|
|
18
33
|
from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
|
|
34
|
+
from msprobe.mindspore.common.log import logger
|
|
19
35
|
|
|
20
36
|
|
|
21
37
|
class MstensorMetaData:
|
|
@@ -26,6 +42,12 @@ class MstensorMetaData:
|
|
|
26
42
|
self.minimum = minimum
|
|
27
43
|
self.shape = shape
|
|
28
44
|
|
|
45
|
+
|
|
46
|
+
class DtypeMetaData:
|
|
47
|
+
def __init__(self, dtype_str) -> None:
|
|
48
|
+
self.dtype_str = dtype_str
|
|
49
|
+
|
|
50
|
+
|
|
29
51
|
class ComputeElement:
|
|
30
52
|
def __init__(self, compute_element_info=None, parameter=None):
|
|
31
53
|
self.supported_parameter_type = tuple(type_to_api_info_type_str.keys()) + tuple([torch.Tensor, tuple])
|
|
@@ -118,6 +140,11 @@ class ComputeElement:
|
|
|
118
140
|
for compute_element in self.parameter])
|
|
119
141
|
elif isinstance(self.parameter, self.supported_parameter_type):
|
|
120
142
|
parameter_tmp = self.parameter
|
|
143
|
+
elif isinstance(self.parameter, DtypeMetaData):
|
|
144
|
+
if tensor_platform == Const.MS_FRAMEWORK:
|
|
145
|
+
parameter_tmp = dtype_str_to_ms_dtype.get(self.parameter.dtype_str)
|
|
146
|
+
else:
|
|
147
|
+
parameter_tmp = dtype_str_to_torch_dtype.get(self.parameter.dtype_str)
|
|
121
148
|
elif isinstance(self.parameter, MstensorMetaData):
|
|
122
149
|
mstensor_meta_data = self.parameter
|
|
123
150
|
ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
|
|
@@ -130,13 +157,13 @@ class ComputeElement:
|
|
|
130
157
|
parameter_tmp = mindspore.Tensor(ndarray, dtype=ms_dtype)
|
|
131
158
|
else:
|
|
132
159
|
err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \
|
|
133
|
-
|
|
160
|
+
"(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)"
|
|
134
161
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
|
|
135
162
|
|
|
136
163
|
# if necessary, do transfer
|
|
137
164
|
if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
|
|
138
165
|
parameter = self.transfer_to_torch_tensor(parameter_tmp)
|
|
139
|
-
elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform ==Const.MS_FRAMEWORK:
|
|
166
|
+
elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK:
|
|
140
167
|
parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
|
|
141
168
|
else:
|
|
142
169
|
parameter = parameter_tmp
|
|
@@ -183,34 +210,38 @@ class ComputeElement:
|
|
|
183
210
|
else:
|
|
184
211
|
type_str = check_and_get_from_json_dict(compute_element_info, "type", "type field in api_info.json",
|
|
185
212
|
accepted_type=str, accepted_value=api_info_type_str_to_type.keys())
|
|
186
|
-
|
|
213
|
+
self.shape = tuple()
|
|
214
|
+
self.dtype_str = type_str
|
|
187
215
|
if type_str == MINDSPORE_TENSOR_TYPE_STR:
|
|
188
216
|
self._init_from_mstensor_compute_element_info(compute_element_info)
|
|
189
|
-
else:
|
|
217
|
+
else:
|
|
190
218
|
value = check_and_get_from_json_dict(compute_element_info, "value", "value field in api_info.json")
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
219
|
+
if type_str == MINDSPORE_DTYPE_TYPE_STR:
|
|
220
|
+
self.parameter = DtypeMetaData(value)
|
|
221
|
+
elif type_str == SLICE_TYPE_STR:
|
|
222
|
+
self.parameter = slice(*tuple(value))
|
|
223
|
+
else: # type_str in ("str", "int", "float", "bool")
|
|
224
|
+
self.parameter = value
|
|
194
225
|
|
|
195
226
|
def _init_from_mstensor_compute_element_info(self, compute_element_info):
|
|
196
227
|
'''
|
|
197
228
|
do not load real tensor, only record meta data
|
|
198
229
|
'''
|
|
199
230
|
dtype_str = check_and_get_from_json_dict(compute_element_info, "dtype", "dtype field in api_info.json",
|
|
200
|
-
|
|
231
|
+
accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys())
|
|
201
232
|
shape = check_and_get_from_json_dict(compute_element_info, "shape", "shape field in api_info.json",
|
|
202
|
-
|
|
233
|
+
accepted_type=(list,))
|
|
203
234
|
if global_context.get_is_constructed():
|
|
204
235
|
maximum = check_and_get_from_json_dict(compute_element_info, "Max", "Max field in api_info.json",
|
|
205
|
-
|
|
236
|
+
accepted_type=(int, float))
|
|
206
237
|
minimum = check_and_get_from_json_dict(compute_element_info, "Min", "Min field in api_info.json",
|
|
207
|
-
|
|
238
|
+
accepted_type=(int, float))
|
|
208
239
|
|
|
209
240
|
npy_path = None
|
|
210
241
|
else:
|
|
211
242
|
maximum, minimum = None, None
|
|
212
243
|
data_name = check_and_get_from_json_dict(compute_element_info, "data_name",
|
|
213
|
-
|
|
244
|
+
"data_name field in api_info.json", accepted_type=(str,))
|
|
214
245
|
npy_path = os.path.join(global_context.get_dump_data_dir(), data_name)
|
|
215
246
|
mstensor_meta_data = MstensorMetaData(dtype_str, npy_path, maximum, minimum, shape)
|
|
216
247
|
self.parameter = mstensor_meta_data
|
|
@@ -219,9 +250,10 @@ class ComputeElement:
|
|
|
219
250
|
|
|
220
251
|
def _init_with_parameter(self, parameter):
|
|
221
252
|
self.parameter = parameter
|
|
253
|
+
self.shape = tuple()
|
|
222
254
|
if not isinstance(parameter, self.supported_parameter_type):
|
|
223
255
|
err_msg = "ComputeElement._init_with_parameter failed: " \
|
|
224
|
-
|
|
256
|
+
"parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)"
|
|
225
257
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
|
|
226
258
|
if isinstance(parameter, mindspore.Tensor):
|
|
227
259
|
self.shape = tuple(parameter.shape)
|
|
@@ -229,11 +261,14 @@ class ComputeElement:
|
|
|
229
261
|
elif isinstance(parameter, torch.Tensor):
|
|
230
262
|
self.shape = tuple(parameter.shape)
|
|
231
263
|
self.dtype_str = torch_dtype_to_dtype_str.get(parameter.dtype)
|
|
264
|
+
elif isinstance(parameter, typing.Type):
|
|
265
|
+
self.dtype_str = MINDSPORE_DTYPE_TYPE_STR
|
|
266
|
+
self.parameter = DtypeMetaData(ms_dtype_to_dtype_str.get(parameter))
|
|
267
|
+
elif isinstance(parameter, torch.dtype):
|
|
268
|
+
self.dtype_str = TORCH_DTYPE_TYPE_STR
|
|
269
|
+
self.parameter = DtypeMetaData(torch_dtype_to_dtype_str.get(parameter))
|
|
232
270
|
elif isinstance(parameter, tuple):
|
|
233
|
-
self.shape = tuple()
|
|
234
271
|
self.dtype_str = TUPLE_TYPE_STR
|
|
235
272
|
self.parameter = tuple([ComputeElement(parameter=param) for param in parameter])
|
|
236
273
|
else:
|
|
237
|
-
self.
|
|
238
|
-
self.dtype_str = \
|
|
239
|
-
TUPLE_TYPE_STR if isinstance(parameter, tuple) else type_to_api_info_type_str.get(type(parameter))
|
|
274
|
+
self.dtype_str = type_to_api_info_type_str.get(type(parameter))
|