mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
# list of api that can be checked
|
|
17
|
+
|
|
18
|
+
tensor:
|
|
19
|
+
- add_
|
|
20
|
+
- add
|
|
21
|
+
- addmm_
|
|
22
|
+
- all
|
|
23
|
+
- allclose
|
|
24
|
+
- any
|
|
25
|
+
- bool
|
|
26
|
+
- byte
|
|
27
|
+
- ceil
|
|
28
|
+
- clamp
|
|
29
|
+
- contiguous
|
|
30
|
+
- copy_
|
|
31
|
+
- cos
|
|
32
|
+
- clone
|
|
33
|
+
- cumprod
|
|
34
|
+
- expand_as
|
|
35
|
+
- flatten
|
|
36
|
+
- float
|
|
37
|
+
- half
|
|
38
|
+
- int
|
|
39
|
+
- is_contiguous
|
|
40
|
+
- isnan
|
|
41
|
+
- item
|
|
42
|
+
- log
|
|
43
|
+
- log2
|
|
44
|
+
- long
|
|
45
|
+
- masked_fill
|
|
46
|
+
- max
|
|
47
|
+
- mean
|
|
48
|
+
- min
|
|
49
|
+
- numel
|
|
50
|
+
- numpy
|
|
51
|
+
- repeat
|
|
52
|
+
- repeat_interleave
|
|
53
|
+
- reshape
|
|
54
|
+
- round
|
|
55
|
+
- select
|
|
56
|
+
- sin
|
|
57
|
+
- size
|
|
58
|
+
- split
|
|
59
|
+
- sqrt
|
|
60
|
+
- square
|
|
61
|
+
- sub
|
|
62
|
+
- swapaxes
|
|
63
|
+
- to
|
|
64
|
+
- t
|
|
65
|
+
- tolist
|
|
66
|
+
- topk
|
|
67
|
+
- transpose
|
|
68
|
+
- trunc
|
|
69
|
+
- type
|
|
70
|
+
- unsqueeze
|
|
71
|
+
- view
|
|
72
|
+
- view_as
|
|
73
|
+
- fill_
|
|
74
|
+
- floor_
|
|
75
|
+
- clamp_
|
|
76
|
+
- type_as
|
|
77
|
+
- zero_
|
|
@@ -1,6 +1,69 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import argparse
|
|
17
|
+
import os
|
|
18
|
+
|
|
19
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path, create_directory
|
|
20
|
+
from msprobe.core.common.utils import Const, MsprobeBaseException
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class UniqueDeviceAction(argparse.Action):
|
|
24
|
+
def __call__(self, parser, namespace, values, option_string=None):
|
|
25
|
+
unique_values = set(values)
|
|
26
|
+
if len(values) != len(unique_values):
|
|
27
|
+
parser.error("device id must be unique")
|
|
28
|
+
for device_id in values:
|
|
29
|
+
if not 0 <= device_id <= 4095:
|
|
30
|
+
parser.error(f"the argument 'device_id' must be in range [0, 4095], but got {device_id}")
|
|
31
|
+
setattr(namespace, self.dest, values)
|
|
32
|
+
|
|
33
|
+
|
|
1
34
|
def add_api_accuracy_checker_argument(parser):
|
|
2
35
|
parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", type=str, required=True,
|
|
3
36
|
help="<Required> The api param tool result file: generate from api param tool, "
|
|
4
37
|
"a json file.")
|
|
5
38
|
parser.add_argument("-o", "--out_path", dest="out_path", default="./", type=str, required=False,
|
|
6
|
-
help="<optional> The ut task result out path.")
|
|
39
|
+
help="<optional> The ut task result out path.")
|
|
40
|
+
parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, required=False,
|
|
41
|
+
help="<optional> the exit csv for continue")
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def multi_add_api_accuracy_checker_argument(parser):
|
|
45
|
+
parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", type=str, required=True,
|
|
46
|
+
help="<Required> The api param tool result file: generate from api param tool, "
|
|
47
|
+
"a json file.")
|
|
48
|
+
parser.add_argument("-o", "--out_path", dest="out_path", default="./", type=str, required=False,
|
|
49
|
+
help="<optional> The ut task result out path.")
|
|
50
|
+
parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, required=False,
|
|
51
|
+
help="<optional> the exit csv for continue")
|
|
52
|
+
#以下属于多线程参数
|
|
53
|
+
parser.add_argument("-d", "--device", dest="device_id", nargs='+', type=int,
|
|
54
|
+
help="<optional> set device id to run ut, must be unique and in range 0-7",
|
|
55
|
+
default=[0], required=False, action=UniqueDeviceAction)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def check_args(args):
|
|
59
|
+
args.api_info_file = os.path.abspath(args.api_info_file)
|
|
60
|
+
check_file_or_directory_path(args.api_info_file)
|
|
61
|
+
|
|
62
|
+
if args.out_path == "":
|
|
63
|
+
args.out_path = "./"
|
|
64
|
+
args.out_path = os.path.abspath(args.out_path)
|
|
65
|
+
create_directory(args.out_path)
|
|
66
|
+
|
|
67
|
+
if args.result_csv_path:
|
|
68
|
+
args.result_csv_path = os.path.abspath(args.result_csv_path)
|
|
69
|
+
check_file_or_directory_path(args.result_csv_path)
|
|
@@ -1,21 +1,37 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
|
|
3
18
|
import mindspore
|
|
4
|
-
import torch
|
|
5
19
|
import numpy as np
|
|
6
|
-
|
|
7
|
-
from
|
|
20
|
+
import torch
|
|
21
|
+
from mindspore._c_expression import typing
|
|
22
|
+
from msprobe.core.common.const import Const
|
|
8
23
|
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
9
24
|
from msprobe.core.common.file_utils import load_npy
|
|
10
|
-
from msprobe.mindspore.api_accuracy_checker.type_mapping import (
|
|
25
|
+
from msprobe.mindspore.api_accuracy_checker.type_mapping import (api_info_type_str_to_type,
|
|
11
26
|
ms_dtype_to_dtype_str, torch_dtype_to_dtype_str,
|
|
12
27
|
dtype_str_to_ms_dtype, dtype_str_to_np_dtype,
|
|
13
28
|
dtype_str_to_torch_dtype, type_to_api_info_type_str,
|
|
14
29
|
DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE, TUPLE_TYPE_STR,
|
|
15
|
-
MINDSPORE_TENSOR_TYPE_STR,
|
|
16
|
-
|
|
17
|
-
|
|
30
|
+
MINDSPORE_TENSOR_TYPE_STR, MINDSPORE_DTYPE_TYPE_STR,
|
|
31
|
+
SLICE_TYPE_STR, TORCH_DTYPE_TYPE_STR,
|
|
32
|
+
float_dtype_str_list, int_dtype_str_list)
|
|
18
33
|
from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
|
|
34
|
+
from msprobe.mindspore.common.log import logger
|
|
19
35
|
|
|
20
36
|
|
|
21
37
|
class MstensorMetaData:
|
|
@@ -26,6 +42,12 @@ class MstensorMetaData:
|
|
|
26
42
|
self.minimum = minimum
|
|
27
43
|
self.shape = shape
|
|
28
44
|
|
|
45
|
+
|
|
46
|
+
class DtypeMetaData:
|
|
47
|
+
def __init__(self, dtype_str) -> None:
|
|
48
|
+
self.dtype_str = dtype_str
|
|
49
|
+
|
|
50
|
+
|
|
29
51
|
class ComputeElement:
|
|
30
52
|
def __init__(self, compute_element_info=None, parameter=None):
|
|
31
53
|
self.supported_parameter_type = tuple(type_to_api_info_type_str.keys()) + tuple([torch.Tensor, tuple])
|
|
@@ -56,12 +78,10 @@ class ComputeElement:
|
|
|
56
78
|
else:
|
|
57
79
|
torch_dtype = dtype_str_to_torch_dtype.get(dtype_str)
|
|
58
80
|
|
|
59
|
-
if dtype_str in
|
|
60
|
-
middle_dtype = mindspore.float64
|
|
61
|
-
elif dtype_str in int_dtype_str_list:
|
|
81
|
+
if dtype_str in int_dtype_str_list:
|
|
62
82
|
middle_dtype = mindspore.int64
|
|
63
83
|
else:
|
|
64
|
-
middle_dtype = mindspore.
|
|
84
|
+
middle_dtype = mindspore.float64
|
|
65
85
|
np_ndarray = ms_tensor.astype(middle_dtype).numpy()
|
|
66
86
|
torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype)
|
|
67
87
|
return torch_tensor
|
|
@@ -84,10 +104,10 @@ class ComputeElement:
|
|
|
84
104
|
else:
|
|
85
105
|
ms_dtype = dtype_str_to_ms_dtype.get(dtype_str)
|
|
86
106
|
|
|
87
|
-
if dtype_str in
|
|
88
|
-
middle_dtype = torch.float64
|
|
89
|
-
elif dtype_str in int_dtype_str_list:
|
|
107
|
+
if dtype_str in int_dtype_str_list:
|
|
90
108
|
middle_dtype = torch.int64
|
|
109
|
+
else:
|
|
110
|
+
middle_dtype = torch.float64
|
|
91
111
|
np_ndarray = torch_tensor.to(middle_dtype, copy=True).numpy()
|
|
92
112
|
ms_tensor = mindspore.Tensor.from_numpy(np_ndarray).astype(ms_dtype)
|
|
93
113
|
return ms_tensor
|
|
@@ -118,6 +138,11 @@ class ComputeElement:
|
|
|
118
138
|
for compute_element in self.parameter])
|
|
119
139
|
elif isinstance(self.parameter, self.supported_parameter_type):
|
|
120
140
|
parameter_tmp = self.parameter
|
|
141
|
+
elif isinstance(self.parameter, DtypeMetaData):
|
|
142
|
+
if tensor_platform == Const.MS_FRAMEWORK:
|
|
143
|
+
parameter_tmp = dtype_str_to_ms_dtype.get(self.parameter.dtype_str)
|
|
144
|
+
else:
|
|
145
|
+
parameter_tmp = dtype_str_to_torch_dtype.get(self.parameter.dtype_str)
|
|
121
146
|
elif isinstance(self.parameter, MstensorMetaData):
|
|
122
147
|
mstensor_meta_data = self.parameter
|
|
123
148
|
ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
|
|
@@ -130,13 +155,13 @@ class ComputeElement:
|
|
|
130
155
|
parameter_tmp = mindspore.Tensor(ndarray, dtype=ms_dtype)
|
|
131
156
|
else:
|
|
132
157
|
err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \
|
|
133
|
-
|
|
158
|
+
"(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)"
|
|
134
159
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
|
|
135
160
|
|
|
136
161
|
# if necessary, do transfer
|
|
137
162
|
if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
|
|
138
163
|
parameter = self.transfer_to_torch_tensor(parameter_tmp)
|
|
139
|
-
elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform ==Const.MS_FRAMEWORK:
|
|
164
|
+
elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK:
|
|
140
165
|
parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
|
|
141
166
|
else:
|
|
142
167
|
parameter = parameter_tmp
|
|
@@ -183,34 +208,38 @@ class ComputeElement:
|
|
|
183
208
|
else:
|
|
184
209
|
type_str = check_and_get_from_json_dict(compute_element_info, "type", "type field in api_info.json",
|
|
185
210
|
accepted_type=str, accepted_value=api_info_type_str_to_type.keys())
|
|
186
|
-
|
|
211
|
+
self.shape = tuple()
|
|
212
|
+
self.dtype_str = type_str
|
|
187
213
|
if type_str == MINDSPORE_TENSOR_TYPE_STR:
|
|
188
214
|
self._init_from_mstensor_compute_element_info(compute_element_info)
|
|
189
|
-
else:
|
|
215
|
+
else:
|
|
190
216
|
value = check_and_get_from_json_dict(compute_element_info, "value", "value field in api_info.json")
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
217
|
+
if type_str == MINDSPORE_DTYPE_TYPE_STR:
|
|
218
|
+
self.parameter = DtypeMetaData(value)
|
|
219
|
+
elif type_str == SLICE_TYPE_STR:
|
|
220
|
+
self.parameter = slice(*tuple(value))
|
|
221
|
+
else: # type_str in ("str", "int", "float", "bool")
|
|
222
|
+
self.parameter = value
|
|
194
223
|
|
|
195
224
|
def _init_from_mstensor_compute_element_info(self, compute_element_info):
|
|
196
225
|
'''
|
|
197
226
|
do not load real tensor, only record meta data
|
|
198
227
|
'''
|
|
199
228
|
dtype_str = check_and_get_from_json_dict(compute_element_info, "dtype", "dtype field in api_info.json",
|
|
200
|
-
|
|
229
|
+
accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys())
|
|
201
230
|
shape = check_and_get_from_json_dict(compute_element_info, "shape", "shape field in api_info.json",
|
|
202
|
-
|
|
231
|
+
accepted_type=(list,))
|
|
203
232
|
if global_context.get_is_constructed():
|
|
204
233
|
maximum = check_and_get_from_json_dict(compute_element_info, "Max", "Max field in api_info.json",
|
|
205
|
-
|
|
234
|
+
accepted_type=(int, float))
|
|
206
235
|
minimum = check_and_get_from_json_dict(compute_element_info, "Min", "Min field in api_info.json",
|
|
207
|
-
|
|
236
|
+
accepted_type=(int, float))
|
|
208
237
|
|
|
209
238
|
npy_path = None
|
|
210
239
|
else:
|
|
211
240
|
maximum, minimum = None, None
|
|
212
241
|
data_name = check_and_get_from_json_dict(compute_element_info, "data_name",
|
|
213
|
-
|
|
242
|
+
"data_name field in api_info.json", accepted_type=(str,))
|
|
214
243
|
npy_path = os.path.join(global_context.get_dump_data_dir(), data_name)
|
|
215
244
|
mstensor_meta_data = MstensorMetaData(dtype_str, npy_path, maximum, minimum, shape)
|
|
216
245
|
self.parameter = mstensor_meta_data
|
|
@@ -219,9 +248,10 @@ class ComputeElement:
|
|
|
219
248
|
|
|
220
249
|
def _init_with_parameter(self, parameter):
|
|
221
250
|
self.parameter = parameter
|
|
251
|
+
self.shape = tuple()
|
|
222
252
|
if not isinstance(parameter, self.supported_parameter_type):
|
|
223
253
|
err_msg = "ComputeElement._init_with_parameter failed: " \
|
|
224
|
-
|
|
254
|
+
"parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)"
|
|
225
255
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
|
|
226
256
|
if isinstance(parameter, mindspore.Tensor):
|
|
227
257
|
self.shape = tuple(parameter.shape)
|
|
@@ -229,11 +259,14 @@ class ComputeElement:
|
|
|
229
259
|
elif isinstance(parameter, torch.Tensor):
|
|
230
260
|
self.shape = tuple(parameter.shape)
|
|
231
261
|
self.dtype_str = torch_dtype_to_dtype_str.get(parameter.dtype)
|
|
262
|
+
elif isinstance(parameter, typing.Type):
|
|
263
|
+
self.dtype_str = MINDSPORE_DTYPE_TYPE_STR
|
|
264
|
+
self.parameter = DtypeMetaData(ms_dtype_to_dtype_str.get(parameter))
|
|
265
|
+
elif isinstance(parameter, torch.dtype):
|
|
266
|
+
self.dtype_str = TORCH_DTYPE_TYPE_STR
|
|
267
|
+
self.parameter = DtypeMetaData(torch_dtype_to_dtype_str.get(parameter))
|
|
232
268
|
elif isinstance(parameter, tuple):
|
|
233
|
-
self.shape = tuple()
|
|
234
269
|
self.dtype_str = TUPLE_TYPE_STR
|
|
235
270
|
self.parameter = tuple([ComputeElement(parameter=param) for param in parameter])
|
|
236
271
|
else:
|
|
237
|
-
self.
|
|
238
|
-
self.dtype_str = \
|
|
239
|
-
TUPLE_TYPE_STR if isinstance(parameter, tuple) else type_to_api_info_type_str.get(type(parameter))
|
|
272
|
+
self.dtype_str = type_to_api_info_type_str.get(type(parameter))
|
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import csv
|
|
18
|
+
|
|
19
|
+
from msprobe.core.common.const import Const, CompareConst, MsCompareConst
|
|
20
|
+
from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, read_csv
|
|
21
|
+
from msprobe.core.common.utils import add_time_as_suffix, MsprobeBaseException
|
|
22
|
+
from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
|
|
23
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path
|
|
24
|
+
from msprobe.mindspore.common.log import logger
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ResultCsvEntry:
    """Per-API accumulator for one row of the result (summary) CSV.

    Holds the forward/backward pass status and the error text gathered for
    each direction while the result CSV rows are being assembled.
    """

    def __init__(self) -> None:
        # A status remains None until a result for that direction is recorded.
        self.forward_pass_status, self.backward_pass_status = None, None
        # Error text per direction; an empty string means nothing to report.
        self.forward_err_msg, self.backward_err_msg = "", ""
        # NOTE(review): not read anywhere in the visible code; kept for compatibility.
        self.overall_err_msg = None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def write_csv_header(csv_path, header_func):
    """Append the header row produced by ``header_func`` to ``csv_path``.

    Meant to be called on the first write to a CSV file; the caller decides
    when that is. The file is opened in append mode, so nothing already in it
    is overwritten.

    Args:
        csv_path: destination CSV file.
        header_func: zero-argument callable returning the list of column names.
    """
    header_row = header_func()
    logger.debug(f"Writing CSV header: {header_row}")
    rows = [header_row]
    write_csv(rows, csv_path, mode="a+")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_result_csv_header():
    """Return the ordered column names of the result (summary) CSV.

    Columns: API name, forward test status, backward test status, message.
    """
    columns = (
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
        MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    )
    return list(columns)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_detail_csv_header():
    """Return the ordered column names of the detail CSV.

    Layout:
    - basic info: API name, bench dtype, tested dtype, shape;
    - one column per registered compare algorithm, in ``compare_algorithms`` key order;
    - pass status and message.
    """
    header = [
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
        MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
        MsCompareConst.DETAIL_CSV_SHAPE,
    ]
    header.extend(compare_algorithms.keys())
    header.extend((
        MsCompareConst.DETAIL_CSV_PASS_STATUS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    ))
    return header
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def check_csv_header(headers, required_constants, csv_path):
    """Validate that every required column name appears in a CSV header row.

    Matching is by substring (``const in header``) rather than equality, so a
    header cell carrying a leading byte-order mark still matches.

    Args:
        headers: the header row read from the CSV (sequence of str).
        required_constants: column names that must be present.
        csv_path: path of the CSV; used only in the error message.

    Returns:
        True when all required columns are present. Callers test the result
        with ``if check_csv_header(...)``; previously this returned None, which
        made that condition always false.

    Raises:
        MsprobeBaseException: with ``MISSING_HEADER_ERROR`` when any required
        column is absent.
    """
    missing_constants = [
        const for const in required_constants
        if not any(const in header for header in headers)
    ]

    if missing_constants:
        raise MsprobeBaseException(
            MsprobeBaseException.MISSING_HEADER_ERROR,
            f"{csv_path} 缺少以下必需的表头字段: {missing_constants}"
        )
    return True
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class DataManager:
    """Buffer per-API compare results and append them to the detail/result CSVs.

    Fresh run: output files are created under ``csv_dir``. The result file name
    comes from ``add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME)``. The
    detail file name is that basename with "result" replaced by "details".

    Resume (``result_csv_path`` given):
    - output goes next to ``result_csv_path``;
    - headers are not rewritten;
    - API names already present in that file are loaded into ``api_names_set``,
      so callers can skip them via ``is_unique_api``.
    """

    def __init__(self, csv_dir, result_csv_path):
        # (api_real_name, forward_or_backward) -> [(basic_info, compare_result_dict), ...]
        self.results = {}
        # api_name -> {Const.FORWARD: err_msg or None, Const.BACKWARD: err_msg or None}
        self.results_exception_skip = {}
        self.is_first_write = True  # True while CSV headers still need to be written
        self.csv_dir = csv_dir
        self.api_names_set = set()  # API names already seen (used by is_unique_api)
        if result_csv_path:
            # Resume mode: continue appending to the previous run's files.
            self.resume_from_last_csv(result_csv_path)
            self.initialize_api_names_set(result_csv_path)
        else:
            # Fresh run: output paths are fixed now; the files are created on first write.
            self.result_out_path = os.path.join(self.csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
            self.detail_out_path = os.path.join(
                self.csv_dir,
                os.path.basename(self.result_out_path).replace("result", "details")
            )

        # Validate any pre-existing output files before appending to them.
        if self.detail_out_path and os.path.exists(self.detail_out_path):
            check_file_or_directory_path(self.detail_out_path)

        if self.result_out_path and os.path.exists(self.result_out_path):
            check_file_or_directory_path(self.result_out_path)

    def initialize_api_names_set(self, result_csv_path):
        """Load the API names recorded in an existing result CSV into ``api_names_set``.

        Raises:
            MsprobeBaseException: (from ``check_csv_header``) if the header row
            lacks a required result-CSV column.
        """
        csv_data = read_csv(result_csv_path, as_pd=False)
        headers = csv_data[0] if csv_data else []  # an empty file yields an empty header row

        # check_csv_header signals a bad header by raising; its return value is
        # deliberately not used as a condition here.
        check_csv_header(headers, get_result_csv_header(), result_csv_path)

        # The header row may start with a byte-order mark, so match by containment.
        api_name_index = next(
            (i for i, header in enumerate(headers) if MsCompareConst.DETAIL_CSV_API_NAME in header),
            None
        )
        if api_name_index is None:
            logger.warning(f"{result_csv_path} No column contains 'API Name'.")
            return

        for row in csv_data[1:]:  # skip the header row
            if row and len(row) > api_name_index:
                api_name = row[api_name_index]
                if api_name:
                    self.api_names_set.add(api_name)

        logger.debug(f"Initialized API names set from existing CSV: {self.api_names_set}")

    def is_unique_api(self, api_name):
        """Return True and remember ``api_name`` if it has not been seen; otherwise return False."""
        if api_name in self.api_names_set:
            return False
        self.api_names_set.add(api_name)
        return True

    def resume_from_last_csv(self, result_csv_path):
        """Point the output paths at the previous run's files so new rows are appended there."""
        last_dir = os.path.dirname(result_csv_path)

        self.csv_dir = last_dir
        self.detail_out_path = os.path.join(last_dir, os.path.basename(result_csv_path).replace("result", "details"))
        if self.detail_out_path and os.path.exists(self.detail_out_path):
            check_file_or_directory_path(self.detail_out_path)
        self.result_out_path = result_csv_path
        # The previous run already wrote the headers.
        self.is_first_write = False

    def save_results(self, api_name_str):
        """Append buffered results to the detail and result CSVs, then clear the buffers.

        Headers are written first if this manager has not written yet.

        Args:
            api_name_str: API name used only in debug log messages.
        """
        if self.is_first_write:
            logger.info("Writing CSV headers for the first time.")
            write_csv_header(self.detail_out_path, get_detail_csv_header)
            write_csv_header(self.result_out_path, get_result_csv_header)
            self.is_first_write = False  # never write the headers twice

        logger.debug("Starting to write detailed output to CSV.")
        self.to_detail_csv(self.detail_out_path)
        logger.debug(f"Detailed output for {api_name_str} written to {self.detail_out_path}.")

        logger.debug("Starting to write result summary to CSV.")
        self.to_result_csv(self.result_out_path)
        logger.debug(f"Result summary for {api_name_str} written to {self.result_out_path}.")

        # Reset the buffers for the next call.
        self.clear_results()

    def record(self, output_list):
        """Buffer compare outputs grouped by ``(api_real_name, forward_or_backward)``.

        Each item of ``output_list`` must be a 4-tuple
        ``(api_real_name, forward_or_backward, basic_info, compare_result_dict)``.
        ``None`` is ignored.
        """
        if output_list is None:
            return
        for api_real_name, forward_or_backward, basic_info, compare_result_dict in output_list:
            key = (api_real_name, forward_or_backward)
            self.results.setdefault(key, []).append((basic_info, compare_result_dict))
            logger.debug(f"Updated self.results for key {key}: {self.results[key]}")
        logger.debug(f"Complete self.results after recording: {self.results}")

    def record_exception_skip(self, api_name, forward_or_backward, err_msg):
        '''
        Record exception-skip information into self.results_exception_skip.

        self.results_exception_skip: dict{str: dict{Const.FORWARD: str/None, Const.BACKWARD: str/None}}
        The outer key is api_name; each inner value is the err_msg for that
        direction, or None if that direction was not skipped.
        '''
        if api_name not in self.results_exception_skip:
            self.results_exception_skip[api_name] = {Const.FORWARD: None, Const.BACKWARD: None}
        self.results_exception_skip[api_name][forward_or_backward] = err_msg

    def clear_results(self):
        """Empty both result buffers (``results`` and ``results_exception_skip``)."""
        logger.debug("Clearing self.results data.")
        self.results.clear()
        self.results_exception_skip.clear()

    def to_detail_csv(self, csv_path):
        """Append one detail row per buffered compare output to ``csv_path``.

        Row layout matches get_detail_csv_header():
        - basic info;
        - one value per compare algorithm, in ``compare_algorithms`` key order;
        - status and error message.
        """
        logger.debug("Preparing detail CSV headers and rows.")
        detail_csv = []

        detail_csv_header_compare_result = list(compare_algorithms.keys())

        for results in self.results.values():
            for basic_info, compare_result_dict in results:
                csv_row_basic_info = [
                    basic_info.api_name,
                    basic_info.bench_dtype,
                    basic_info.tested_dtype,
                    basic_info.shape
                ]
                # NOTE(review): raises AttributeError if an algorithm is missing
                # from compare_result_dict (.get returns None).
                csv_row_compare_result = [
                    compare_result_dict.get(algorithm_name).compare_value
                    for algorithm_name in detail_csv_header_compare_result
                ]
                csv_row_status = [basic_info.status, basic_info.err_msg]
                csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
                detail_csv.append(csv_row)
                logger.debug(f"Detail CSV row added: {csv_row}")

        logger.debug(f"Writing detail CSV to {csv_path}.")
        write_csv(detail_csv, csv_path, mode="a+")
        logger.debug(f"Detail CSV written successfully to {csv_path}.")

    def to_result_csv(self, csv_path):
        '''
        Append one summary row per API to the result CSV at ``csv_path``.

        Depends on both self.results and self.results_exception_skip:
        - A direction is ERROR if any of its outputs did not PASS; its error
          messages are concatenated.
        - A direction recorded in results_exception_skip is reported as SKIP,
          and its message is appended.
        - Skip entries consumed while building rows are removed from
          results_exception_skip. APIs with only skip information get their own
          rows after the others.
        '''
        logger.debug("Preparing result CSV data.")
        result_csv = []

        result_csv_dict = {}
        for (api_real_name, forward_or_backward), results in self.results.items():
            pass_status = CompareConst.PASS
            overall_err_msg = ""  # stays "" when every output passed
            for basic_info, _ in results:
                if basic_info.status != CompareConst.PASS:
                    pass_status = CompareConst.ERROR
                    overall_err_msg += basic_info.err_msg

            if api_real_name not in result_csv_dict:
                result_csv_dict[api_real_name] = ResultCsvEntry()
            entry = result_csv_dict[api_real_name]
            if forward_or_backward == Const.FORWARD:
                entry.forward_pass_status = pass_status
                entry.forward_err_msg = overall_err_msg
            else:
                entry.backward_pass_status = pass_status
                entry.backward_err_msg = overall_err_msg

        for api_name, entry in result_csv_dict.items():
            both_passed = (entry.forward_pass_status == CompareConst.PASS and
                           entry.backward_pass_status == CompareConst.PASS)
            overall_err_msg = "" if both_passed else entry.forward_err_msg + entry.backward_err_msg
            row = [
                api_name,
                entry.forward_pass_status,
                entry.backward_pass_status,
                overall_err_msg
            ]
            # Override the status of any direction skipped because of an exception.
            skip_info = self.results_exception_skip.pop(api_name, None)
            if skip_info is not None:
                if skip_info[Const.FORWARD] is not None:
                    row[1] = CompareConst.SKIP
                    row[-1] += skip_info[Const.FORWARD]
                if skip_info[Const.BACKWARD] is not None:
                    row[2] = CompareConst.SKIP
                    row[-1] += skip_info[Const.BACKWARD]
            result_csv.append(row)
            logger.debug(f"Result CSV row added: {row}")

        # APIs that only have skip information (no compare results) get their own rows.
        for api_name, current_exception_skip in self.results_exception_skip.items():
            forward_status = None
            backward_status = None
            err_msg = ""
            if current_exception_skip[Const.FORWARD] is not None:
                forward_status = CompareConst.SKIP
                err_msg += current_exception_skip[Const.FORWARD]
            if current_exception_skip[Const.BACKWARD] is not None:
                backward_status = CompareConst.SKIP
                err_msg += current_exception_skip[Const.BACKWARD]
            result_csv.append([api_name, forward_status, backward_status, err_msg])

        write_csv(result_csv, csv_path, mode="a+")
        logger.debug(f"Result CSV written successfully to {csv_path}.")

        # Headers must not be added again after any write.
        self.is_first_write = False
|
|
@@ -1,9 +1,34 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker
|
|
2
17
|
|
|
18
|
+
from msprobe.mindspore.api_accuracy_checker.multi_api_accuracy_checker import MultiApiAccuracyChecker
|
|
19
|
+
|
|
20
|
+
from msprobe.mindspore.api_accuracy_checker.cmd_parser import check_args
|
|
21
|
+
|
|
3
22
|
|
|
4
23
|
def _run_checker(checker_cls, args):
    """Run one API accuracy check with ``checker_cls``.

    Steps: validate ``args``, build ``checker_cls(args)``, parse
    ``args.api_info_file``, then run the comparison.
    """
    check_args(args)
    checker = checker_cls(args)
    checker.parse(args.api_info_file)
    checker.run_and_compare()


def api_checker_main(args):
    """Run API accuracy checking with ``ApiAccuracyChecker``."""
    _run_checker(ApiAccuracyChecker, args)


def mul_api_checker_main(args):
    """Run API accuracy checking with ``MultiApiAccuracyChecker``."""
    _run_checker(MultiApiAccuracyChecker, args)
|
|
8
|
-
api_accuracy_checker.to_detail_csv(args.out_path)
|
|
9
|
-
api_accuracy_checker.to_result_csv(args.out_path)
|