mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
# 动态图精度数据采集快速入门示例
|
|
2
|
+
|
|
3
|
+
本示例将展示如何在 MindSpore 动态图模式下使用 msprobe 工具进行精度数据采集。
|
|
4
|
+
|
|
5
|
+
## 1. 配置文件
|
|
6
|
+
|
|
7
|
+
请在当前目录(即下文模型脚本 `alexnet_model.py` 所在的目录,脚本会从自身所在目录读取该文件)下创建一个名为 `config.json` 的配置文件,内容如下:
|
|
8
|
+
|
|
9
|
+
```json
|
|
10
|
+
{
|
|
11
|
+
"task": "statistics",
|
|
12
|
+
"dump_path": "./output",
|
|
13
|
+
"rank": [],
|
|
14
|
+
"step": ["0-2"],
|
|
15
|
+
"level": "L1",
|
|
16
|
+
"statistics": {
|
|
17
|
+
"scope": [],
|
|
18
|
+
"list": [],
|
|
19
|
+
"data_mode": [
|
|
20
|
+
"all"
|
|
21
|
+
],
|
|
22
|
+
"summary_mode": "statistics"
|
|
23
|
+
}
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
```
|
|
27
|
+
以上配置参数的详细介绍和使用方法请参见[《config.json 配置文件介绍》](../02.config_introduction.md)以及[《config.json 配置示例》](../03.config_examples.md#3-mindspore-动态图场景)中的“MindSpore动态图场景”。
|
|
28
|
+
|
|
29
|
+
## 2. 模型脚本
|
|
30
|
+
|
|
31
|
+
在当前目录下创建一个 Python 脚本文件,例如 `alexnet_model.py`,将以下代码粘贴进去:
|
|
32
|
+
|
|
33
|
+
```python
|
|
34
|
+
import os
|
|
35
|
+
import numpy as np
|
|
36
|
+
import mindspore as ms
|
|
37
|
+
from mindspore import nn, ops
|
|
38
|
+
from mindspore import context
|
|
39
|
+
from mindspore import Tensor
|
|
40
|
+
from msprobe.mindspore import PrecisionDebugger, seed_all
|
|
41
|
+
|
|
42
|
+
# Fix the random seeds so that runs are reproducible.
# NOTE(review): mode=False / rm_dropout=True are msprobe seed_all options; confirm
# their exact meaning against the msprobe MindSpore documentation.
seed_all(seed=1234, mode=False, rm_dropout=True)

# Resolve config.json relative to this script's own directory (not the process CWD).
script_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(script_dir, 'config.json')

# Create the precision debugger from the config file; the __main__ loop drives it
# with start()/stop()/step().
debugger = PrecisionDebugger(config_path=config_path)

# Run MindSpore in PyNative (dynamic graph) mode on Ascend device 0.
context.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend", device_id=0)
|
|
54
|
+
|
|
55
|
+
def conv_layer(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="valid", has_bias=True):
    """Create an ``nn.Conv2d`` layer.

    Small factory so every convolution in AlexNet is built the same way.
    All geometry arguments are forwarded to ``nn.Conv2d`` unchanged.
    """
    conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding,
                        has_bias=has_bias, pad_mode=pad_mode)
    return nn.Conv2d(in_channels, out_channels, **conv_options)
|
|
59
|
+
|
|
60
|
+
def fc_layer(input_channels, out_channels, has_bias=True):
    """Create a fully connected ``nn.Dense`` layer mapping input_channels -> out_channels."""
    dense = nn.Dense(input_channels, out_channels, has_bias=has_bias)
    return dense
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class AlexNet(nn.Cell):
    """
    AlexNet-style network used as the model in this data-collection example.

    Args:
    - num_classes: number of output classes (output size of fc3).
    - channel: number of input channels (colour channels of the image).
    - phase: run phase ('train' or 'test'). NOTE(review): accepted but never
      read by this class.
    - include_top: whether to build and apply the fully connected
      classification head (flatten + fc1/fc2/fc3).
    """
    def __init__(self, num_classes=10, channel=3, phase='train', include_top=True):
        super(AlexNet, self).__init__()

        # Convolution layers; all use pad_mode="same", only conv1 is strided (stride 4).
        self.conv1 = conv_layer(channel, 64, 11, stride=4, pad_mode="same")
        self.conv2 = conv_layer(64, 128, 5, pad_mode="same")
        self.conv3 = conv_layer(128, 192, 3, pad_mode="same")
        self.conv4 = conv_layer(192, 256, 3, pad_mode="same")
        self.conv5 = conv_layer(256, 256, 3, pad_mode="same")

        # Activation and pooling layers; the single self.relu is reused after every conv stage.
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')

        # Optional fully connected head.
        self.include_top = include_top
        if self.include_top:
            self.flatten = nn.Flatten()
            # 256 channels x 28 x 28 positions after conv5. 28x28 matches the
            # 227x227 input generated in __main__: conv1 (stride 4, "same") -> 57x57,
            # 3x3/stride-2 "valid" max-pool -> 28x28, stride-1 "same" convs keep 28x28.
            # A different input size would not match fc1's input width.
            self.fc1 = fc_layer(256 * 28 * 28, 4096)
            self.fc2 = fc_layer(4096, 4096)
            self.fc3 = fc_layer(4096, num_classes)

        # Element-wise math primitives used in construct().
        self.add = ops.Add()
        self.mul = ops.Mul()

    def construct(self, x):
        """Forward pass; prints intermediate shapes as a debugging aid."""
        # NOTE(review): the dump records the APIs called here in execution order;
        # reordering or replacing these ops presumably changes the collected data.

        x = self.conv1(x)
        x = self.add(x, 0.1)  # add a constant offset
        x = self.mul(x, 2.0)  # scale by 2
        x = self.relu(x)  # ReLU activation
        x = ops.celu(x)
        x = x + 2

        # Print the output shape; useful while debugging.
        print(f"After Conv1: {x.shape}")

        x = self.max_pool2d(x)  # max pooling
        print(f"After MaxPool: {x.shape}")  # shape after pooling

        x = self.conv2(x)
        x = self.relu(x)

        x = self.conv3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.relu(x)

        x = self.conv5(x)
        x = self.relu(x)

        # Print the shape after the conv stack; useful while debugging.
        print(f"After Conv5: {x.shape}")

        # Optional fully connected head (no activation between fc layers here).
        if self.include_top:
            x = self.flatten(x)
            x = self.fc1(x)
            x = self.fc2(x)
            x = self.fc3(x)

        return x
|
|
140
|
+
|
|
141
|
+
def forward_fn(data, label):
    """Forward pass plus loss for one batch.

    Runs the module-level ``net`` on ``data`` and scores the result against
    ``label`` with the module-level ``criterion``; returns the loss.
    """
    logits = net(data)
    return criterion(logits, label)
|
|
146
|
+
|
|
147
|
+
def train_step(data, label):
    """Run one optimisation step.

    Computes the loss and gradients with the module-level ``grad_fn``, applies
    the gradients through ``optimizer`` and returns the loss (computed before
    the parameter update).
    """
    step_loss, gradients = grad_fn(data, label)
    optimizer(gradients)
    return step_loss
|
|
152
|
+
|
|
153
|
+
# Entry point: build the model, optimizer and loss, then run a few training
# steps with msprobe data collection wrapped around each one.
if __name__ == "__main__":
    # net, optimizer, criterion and grad_fn are module-level globals here;
    # forward_fn() and train_step() read them.
    net = AlexNet()
    optimizer = nn.SGD(net.trainable_params(), learning_rate=0.01)
    criterion = nn.MSELoss()

    # Differentiate forward_fn w.r.t. the optimizer's parameters only
    # (grad_position=None: no gradients for the inputs).
    grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

    # Generate random input data and labels.
    batch_size = 1
    num_classes = 10
    data = np.random.normal(1, 1, (batch_size, 3, 227, 227)).astype(np.float32)
    # NOTE(review): label has shape (batch_size,) while the net outputs
    # (batch_size, num_classes); MSELoss presumably broadcasts them. This is a
    # smoke-test loss, not a real classification objective.
    label = np.random.randint(0, num_classes, (batch_size,)).astype(np.float32)  # keep float32 (presumably required by MSELoss — confirm)

    # Convert to MindSpore tensors.
    data = Tensor(data)
    label = Tensor(label)

    steps = 5
    for i in range(steps):
        debugger.start(net)  # start data collection for this step
        loss = train_step(data, label)  # run one training step
        print(f"Step {i}, Loss: {loss}")
        debugger.stop()  # stop data collection
        debugger.step()  # advance the debugger's step counter
|
|
178
|
+
```
|
|
179
|
+
|
|
180
|
+
## 3. 运行训练脚本
|
|
181
|
+
|
|
182
|
+
在命令行中执行以下命令:
|
|
183
|
+
|
|
184
|
+
```bash
|
|
185
|
+
python alexnet_model.py
|
|
186
|
+
```
|
|
187
|
+
|
|
188
|
+
## 4. 查看采集结果
|
|
189
|
+
|
|
190
|
+
执行训练命令后,工具会将模型训练过程中的精度数据采集下来。
|
|
191
|
+
|
|
192
|
+
日志中出现如下信息即表示数据采集成功,此时可手动停止模型训练并查看采集数据。
|
|
193
|
+
|
|
194
|
+
```markdown
|
|
195
|
+
****************************************************************************
|
|
196
|
+
* msprobe ends successfully. *
|
|
197
|
+
****************************************************************************
|
|
198
|
+
```
|
|
199
|
+
|
|
200
|
+
## 5. 数据分析
|
|
201
|
+
|
|
202
|
+
在 `dump_path` 参数指定的路径下(本例中为 `./output`)会生成如下目录结构(以 step0 为例;配置中 `step` 为 `["0-2"]`,其余 step 目录结构相同)。后续精度数据分析可使用 msprobe 工具的精度预检和精度比对等功能,详细流程请参见[《msprobe使用手册》](../../README.md#2-精度预检)。
|
|
203
|
+
|
|
204
|
+
```bash
|
|
205
|
+
output/
|
|
206
|
+
└── step0
|
|
207
|
+
└── rank
|
|
208
|
+
├── construct.json # level为L0时保存Cell的层级关系信息;本例level为L1,因此为空
|
|
209
|
+
├── dump.json # 保存API前反向输入输出数据的统计量信息
|
|
210
|
+
└── stack.json # 保存API的调用栈
|
|
211
|
+
```
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
msprobe/mindspore/__init__.py
CHANGED
|
@@ -1 +1,17 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
|
|
17
|
+
from msprobe.mindspore.common.utils import seed_all
|
|
@@ -1,16 +1,34 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
2
16
|
import os
|
|
17
|
+
from tqdm import tqdm
|
|
3
18
|
|
|
4
|
-
from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv
|
|
5
|
-
from msprobe.core.common.utils import add_time_as_suffix
|
|
6
19
|
from msprobe.core.common.const import Const, CompareConst, MsCompareConst
|
|
7
|
-
from msprobe.
|
|
20
|
+
from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, load_json, load_yaml
|
|
21
|
+
from msprobe.core.common.utils import add_time_as_suffix
|
|
8
22
|
from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo
|
|
9
23
|
from msprobe.mindspore.api_accuracy_checker.api_runner import api_runner, ApiInputAggregation
|
|
10
24
|
from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
|
|
25
|
+
from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager
|
|
11
26
|
from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context,
|
|
12
27
|
trim_output_compute_element_list)
|
|
28
|
+
from msprobe.mindspore.common.log import logger
|
|
13
29
|
|
|
30
|
+
# Directory containing this module (symlinks resolved via realpath).
cur_path = os.path.dirname(os.path.realpath(__file__))
# Path of the YAML file listing the APIs the accuracy checker supports
# (file name from MsCompareConst.SUPPORTED_API_LIST_FILE); read by is_api_checkable.
yaml_path = os.path.join(cur_path, MsCompareConst.SUPPORTED_API_LIST_FILE)
|
|
14
32
|
|
|
15
33
|
class BasicInfoAndStatus:
|
|
16
34
|
def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
|
|
@@ -21,6 +39,7 @@ class BasicInfoAndStatus:
|
|
|
21
39
|
self.status = status
|
|
22
40
|
self.err_msg = err_msg
|
|
23
41
|
|
|
42
|
+
|
|
24
43
|
class ResultCsvEntry:
|
|
25
44
|
def __init__(self) -> None:
|
|
26
45
|
self.forward_pass_status = None
|
|
@@ -31,9 +50,9 @@ class ResultCsvEntry:
|
|
|
31
50
|
|
|
32
51
|
|
|
33
52
|
class ApiAccuracyChecker:
|
|
34
|
-
def __init__(self):
|
|
53
|
+
def __init__(self, args):
|
|
35
54
|
self.api_infos = dict()
|
|
36
|
-
self.
|
|
55
|
+
self.data_manager = DataManager(args.out_path, args.result_csv_path) # 在初始化时实例化 DataManager
|
|
37
56
|
|
|
38
57
|
@staticmethod
|
|
39
58
|
def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
|
|
@@ -80,25 +99,64 @@ class ApiAccuracyChecker:
|
|
|
80
99
|
compare_result_dict[compare_algorithm_name] = compare_result
|
|
81
100
|
|
|
82
101
|
if compare_result_dict.get(CompareConst.COSINE).pass_status == CompareConst.PASS and \
|
|
83
|
-
|
|
102
|
+
compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
|
|
84
103
|
status = CompareConst.PASS
|
|
85
104
|
err_msg = ""
|
|
86
105
|
else:
|
|
87
106
|
status = CompareConst.ERROR
|
|
88
107
|
err_msg = compare_result_dict.get(CompareConst.COSINE).err_msg + \
|
|
89
|
-
|
|
108
|
+
compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg
|
|
90
109
|
basic_info_status = \
|
|
91
110
|
BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
|
|
92
111
|
output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
|
|
93
112
|
return output_list
|
|
94
113
|
|
|
114
|
+
@staticmethod
def prepare_api_input_aggregation(api_info, forward_or_backward=Const.FORWARD):
    """Bundle the inputs needed to replay one API call.

    Args:
        api_info: ApiInfo holding the dumped data for the API.
        forward_or_backward: str, Const.FORWARD or Const.BACKWARD.

    Returns:
        ApiInputAggregation built from the forward inputs and kwargs; for any
        value other than Const.FORWARD it also carries the backward-pass input
        gradients, otherwise the gradient slot is None.
    """
    forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
    kwargs = api_info.get_kwargs()
    is_forward = forward_or_backward == Const.FORWARD
    gradient_inputs = (None if is_forward
                       else api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT))
    return ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
|
|
130
|
+
|
|
131
|
+
@staticmethod
def is_api_checkable(api_name_str):
    '''
    Args:
        api_name_str: str, e.g. "MintFunctional.relu.0.forward", key in data field of api_info.json
    Returns:
        is_checkable: bool
    Description:
        Tell whether this api is checkable based on the key in the "data" dict of api_info.json.
        mint / mint.functional APIs are always checkable; Tensor APIs are checkable only when
        listed under MsCompareConst.SUPPORTED_TENSOR_LIST_KEY in the YAML file at yaml_path.
        Any other API type is not checkable.
    '''
    api_name_str_list = api_name_str.split(Const.SEP)
    if len(api_name_str_list) < MsCompareConst.API_NAME_STR_LENGTH:
        return False
    api_type_str = api_name_str_list[0]
    # mint APIs never need the YAML lookup, so decide them before touching the file
    # (previously the YAML was re-read and parsed for every API name).
    if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL):
        return True
    if api_type_str != MsCompareConst.TENSOR_API:
        return False
    # Strip the type prefix and the trailing "<index>.<direction>" parts,
    # e.g. "Tensor.add.0.forward" -> "add".
    real_api_str = Const.SEP.join(api_name_str_list[1:-2])
    # Guard against an empty YAML document or a missing key, which previously
    # made the membership test below raise TypeError on None.
    api_list = load_yaml(yaml_path) or {}
    supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY) or []
    return real_api_str in supported_tensor_api_list
|
|
153
|
+
|
|
95
154
|
def parse(self, api_info_path):
|
|
96
|
-
|
|
97
|
-
api_info_dict = json.load(f)
|
|
155
|
+
api_info_dict = load_json(api_info_path)
|
|
98
156
|
|
|
99
157
|
# init global context
|
|
100
158
|
task = check_and_get_from_json_dict(api_info_dict, MsCompareConst.TASK_FIELD,
|
|
101
|
-
"task field in api_info.json",accepted_type=str,
|
|
159
|
+
"task field in api_info.json", accepted_type=str,
|
|
102
160
|
accepted_value=(MsCompareConst.STATISTICS_TASK,
|
|
103
161
|
MsCompareConst.TENSOR_TASK))
|
|
104
162
|
is_constructed = task == MsCompareConst.STATISTICS_TASK
|
|
@@ -112,14 +170,12 @@ class ApiAccuracyChecker:
|
|
|
112
170
|
api_info_data = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DATA_FIELD,
|
|
113
171
|
"data field in api_info.json", accepted_type=dict)
|
|
114
172
|
for api_name, api_info in api_info_data.items():
|
|
115
|
-
|
|
116
|
-
(MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL)
|
|
117
|
-
if not is_mint:
|
|
173
|
+
if not self.is_api_checkable(api_name):
|
|
118
174
|
continue
|
|
119
175
|
forbackward_str = api_name.split(Const.SEP)[-1]
|
|
120
176
|
if forbackward_str not in (Const.FORWARD, Const.BACKWARD):
|
|
121
177
|
logger.warning(f"api: {api_name} is not recognized as forward api or backward api, skip this.")
|
|
122
|
-
api_name = Const.SEP.join(api_name.split(Const.SEP)[:-1])
|
|
178
|
+
api_name = Const.SEP.join(api_name.split(Const.SEP)[:-1]) # www.xxx.yyy.zzz --> www.xxx.yyy
|
|
123
179
|
if api_name not in self.api_infos:
|
|
124
180
|
self.api_infos[api_name] = ApiInfo(api_name)
|
|
125
181
|
|
|
@@ -128,128 +184,64 @@ class ApiAccuracyChecker:
|
|
|
128
184
|
else:
|
|
129
185
|
self.api_infos[api_name].load_backward_info(api_info)
|
|
130
186
|
|
|
187
|
+
def process_forward(self, api_name_str, api_info):
|
|
188
|
+
"""处理前向检查"""
|
|
189
|
+
if not api_info.check_forward_info():
|
|
190
|
+
logger.debug(f"api: {api_name_str} is lack of forward information, skip forward check.")
|
|
191
|
+
return Const.EXCEPTION_NONE
|
|
192
|
+
|
|
193
|
+
try:
|
|
194
|
+
forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
|
|
195
|
+
except Exception as e:
|
|
196
|
+
logger.warning(f"Exception occurs when getting inputs for {api_name_str} forward api. "
|
|
197
|
+
f"Skipping forward check. Detailed exception information: {e}.")
|
|
198
|
+
return Const.EXCEPTION_NONE
|
|
199
|
+
|
|
200
|
+
forward_output_list = None
|
|
201
|
+
try:
|
|
202
|
+
forward_output_list = self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
|
|
203
|
+
except Exception as e:
|
|
204
|
+
logger.warning(f"Exception occurs when running and comparing {api_name_str} forward api. "
|
|
205
|
+
f"Detailed exception information: {e}.")
|
|
206
|
+
return forward_output_list
|
|
207
|
+
|
|
208
|
+
def process_backward(self, api_name_str, api_info):
|
|
209
|
+
"""处理反向检查"""
|
|
210
|
+
if not api_info.check_backward_info():
|
|
211
|
+
logger.debug(f"api: {api_name_str} is lack of backward information, skipping backward check.")
|
|
212
|
+
return Const.EXCEPTION_NONE
|
|
213
|
+
|
|
214
|
+
try:
|
|
215
|
+
backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
|
|
216
|
+
except Exception as e:
|
|
217
|
+
logger.warning(f"Exception occurs when getting inputs for {api_name_str} backward api. "
|
|
218
|
+
f"Skipping backward check. Detailed exception information: {e}.")
|
|
219
|
+
return Const.EXCEPTION_NONE
|
|
220
|
+
|
|
221
|
+
backward_output_list = None
|
|
222
|
+
try:
|
|
223
|
+
backward_output_list = self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation, Const.BACKWARD)
|
|
224
|
+
except Exception as e:
|
|
225
|
+
logger.warning(f"Exception occurs when running and comparing {api_name_str} backward api. "
|
|
226
|
+
f"Detailed exception information: {e}.")
|
|
227
|
+
return backward_output_list
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
|
|
131
231
|
def run_and_compare(self):
|
|
132
|
-
for api_name_str, api_info in self.api_infos.items():
|
|
133
|
-
if not
|
|
134
|
-
logger.warning(f"api: {api_name_str} is lack of forward infomation, skip forward and backward check")
|
|
135
|
-
continue
|
|
136
|
-
forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
|
|
137
|
-
kwargs = api_info.get_kwargs()
|
|
138
|
-
forward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, None)
|
|
139
|
-
forward_output_list = None
|
|
140
|
-
try:
|
|
141
|
-
forward_output_list = \
|
|
142
|
-
self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
|
|
143
|
-
except Exception as e:
|
|
144
|
-
logger.warning(f"exception occurs when running and comparing {api_name_str} forward api"
|
|
145
|
-
f"detailed exception information: {e}")
|
|
146
|
-
self.record(forward_output_list)
|
|
147
|
-
|
|
148
|
-
if not api_info.check_backward_info():
|
|
149
|
-
logger.warning(f"api: {api_name_str} is lack of backward infomation, skip backward check")
|
|
232
|
+
for api_name_str, api_info in tqdm(self.api_infos.items()):
|
|
233
|
+
if not self.data_manager.is_unique_api(api_name_str):
|
|
150
234
|
continue
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
return
|
|
165
|
-
for output in output_list:
|
|
166
|
-
api_real_name, forward_or_backward, basic_info, compare_result_dict = output
|
|
167
|
-
key = tuple([api_real_name, forward_or_backward])
|
|
168
|
-
if key not in self.results:
|
|
169
|
-
self.results[key] = []
|
|
170
|
-
self.results[key].append(tuple([basic_info, compare_result_dict]))
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
def to_detail_csv(self, csv_dir):
|
|
174
|
-
# detail_csv
|
|
175
|
-
detail_csv = []
|
|
176
|
-
detail_csv_header_basic_info = [
|
|
177
|
-
MsCompareConst.DETAIL_CSV_API_NAME,
|
|
178
|
-
MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
|
|
179
|
-
MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
|
|
180
|
-
MsCompareConst.DETAIL_CSV_SHAPE,
|
|
181
|
-
]
|
|
182
|
-
detail_csv_header_compare_result = list(compare_algorithms.keys())
|
|
183
|
-
detail_csv_header_status = [
|
|
184
|
-
MsCompareConst.DETAIL_CSV_PASS_STATUS,
|
|
185
|
-
MsCompareConst.DETAIL_CSV_MESSAGE,
|
|
186
|
-
]
|
|
187
|
-
|
|
188
|
-
detail_csv_header = detail_csv_header_basic_info + detail_csv_header_compare_result + detail_csv_header_status
|
|
189
|
-
detail_csv.append(detail_csv_header)
|
|
190
|
-
|
|
191
|
-
for _, results in self.results.items():
|
|
192
|
-
# detail csv
|
|
193
|
-
for res in results:
|
|
194
|
-
basic_info, compare_result_dict = res
|
|
195
|
-
csv_row_basic_info = \
|
|
196
|
-
[basic_info.api_name, basic_info.bench_dtype, basic_info.tested_dtype, basic_info.shape]
|
|
197
|
-
csv_row_compare_result = list(compare_result_dict.get(algorithm_name).compare_value \
|
|
198
|
-
for algorithm_name in detail_csv_header_compare_result)
|
|
199
|
-
csv_row_status = [basic_info.status, basic_info.err_msg]
|
|
200
|
-
csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
|
|
201
|
-
detail_csv.append(csv_row)
|
|
202
|
-
|
|
203
|
-
file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.DETAIL_CSV_FILE_NAME))
|
|
204
|
-
create_directory(csv_dir)
|
|
205
|
-
write_csv(detail_csv, file_name, mode="w")
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
def to_result_csv(self, csv_dir):
|
|
209
|
-
result_csv_dict = dict()
|
|
210
|
-
for key, results in self.results.items():
|
|
211
|
-
api_real_name, forward_or_backward = key
|
|
212
|
-
forward_or_backward_pass_status = CompareConst.PASS
|
|
213
|
-
forward_or_backward_overall_err_msg = ""
|
|
214
|
-
# detail csv
|
|
215
|
-
for res in results:
|
|
216
|
-
basic_info, _ = res
|
|
217
|
-
if basic_info.status != CompareConst.PASS:
|
|
218
|
-
forward_or_backward_pass_status = CompareConst.ERROR
|
|
219
|
-
forward_or_backward_overall_err_msg += basic_info.err_msg
|
|
220
|
-
forward_or_backward_overall_err_msg = \
|
|
221
|
-
"" if forward_or_backward_pass_status == CompareConst.PASS else forward_or_backward_overall_err_msg
|
|
222
|
-
|
|
223
|
-
#result_csv_dict
|
|
224
|
-
if api_real_name not in result_csv_dict:
|
|
225
|
-
result_csv_dict[api_real_name] = ResultCsvEntry()
|
|
226
|
-
if forward_or_backward == Const.FORWARD:
|
|
227
|
-
result_csv_dict[api_real_name].forward_pass_status = forward_or_backward_pass_status
|
|
228
|
-
result_csv_dict[api_real_name].forward_err_msg = forward_or_backward_overall_err_msg
|
|
229
|
-
else:
|
|
230
|
-
result_csv_dict[api_real_name].backward_pass_status = forward_or_backward_pass_status
|
|
231
|
-
result_csv_dict[api_real_name].backward_err_msg = forward_or_backward_overall_err_msg
|
|
232
|
-
|
|
233
|
-
#result_csv
|
|
234
|
-
result_csv = []
|
|
235
|
-
result_csv_header = [
|
|
236
|
-
MsCompareConst.DETAIL_CSV_API_NAME,
|
|
237
|
-
MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
|
|
238
|
-
MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
|
|
239
|
-
MsCompareConst.DETAIL_CSV_MESSAGE,
|
|
240
|
-
]
|
|
241
|
-
result_csv.append(result_csv_header)
|
|
242
|
-
|
|
243
|
-
for api_name, result_csv_entry in result_csv_dict.items():
|
|
244
|
-
if result_csv_entry.forward_pass_status == CompareConst.PASS and \
|
|
245
|
-
result_csv_entry.backward_pass_status == CompareConst.PASS:
|
|
246
|
-
overall_err_msg = ""
|
|
247
|
-
else:
|
|
248
|
-
overall_err_msg = result_csv_entry.forward_err_msg + result_csv_entry.backward_err_msg
|
|
249
|
-
row = [api_name, result_csv_entry.forward_pass_status,
|
|
250
|
-
result_csv_entry.backward_pass_status, overall_err_msg]
|
|
251
|
-
result_csv.append(row)
|
|
252
|
-
|
|
253
|
-
file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
|
|
254
|
-
create_directory(csv_dir)
|
|
255
|
-
write_csv(result_csv, file_name, mode="w")
|
|
235
|
+
|
|
236
|
+
# 处理前向
|
|
237
|
+
forward_output_list = self.process_forward(api_name_str, api_info)
|
|
238
|
+
if forward_output_list is not Const.EXCEPTION_NONE:
|
|
239
|
+
self.data_manager.record(forward_output_list)
|
|
240
|
+
|
|
241
|
+
# 处理反向
|
|
242
|
+
backward_output_list = self.process_backward(api_name_str, api_info)
|
|
243
|
+
if backward_output_list is not Const.EXCEPTION_NONE:
|
|
244
|
+
self.data_manager.record(backward_output_list)
|
|
245
|
+
|
|
246
|
+
self.data_manager.save_results(api_name_str)
|
|
247
|
+
|
|
@@ -1,11 +1,34 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
2
16
|
from msprobe.core.common.const import Const
|
|
3
|
-
from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
|
|
4
17
|
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
18
|
+
from msprobe.core.common.utils import is_invalid_pattern
|
|
19
|
+
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
20
|
+
from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
|
|
5
21
|
from msprobe.mindspore.common.log import logger
|
|
6
22
|
|
|
23
|
+
|
|
7
24
|
class ApiInfo:
|
|
8
25
|
def __init__(self, api_name):
|
|
26
|
+
if not isinstance(api_name, str):
|
|
27
|
+
err_msg = "ApiInfo.__init__ failed: api_name is not a string"
|
|
28
|
+
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
|
|
29
|
+
if is_invalid_pattern(api_name):
|
|
30
|
+
err_msg = "ApiInfo.__init__ failed: api_name contain illegal character"
|
|
31
|
+
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
|
|
9
32
|
self.api_name = api_name
|
|
10
33
|
self.forward_info = None
|
|
11
34
|
self.backward_info = None
|
|
@@ -59,11 +82,10 @@ class ApiInfo:
|
|
|
59
82
|
err_msg = "ApiInfo.get_kwargs failed: compute_element_dict key is not a string"
|
|
60
83
|
logger.error_log_with_exp(err_msg,
|
|
61
84
|
ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
|
|
62
|
-
if not isinstance(compute_element_info, (list, dict)):
|
|
63
|
-
err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list or
|
|
85
|
+
if not (isinstance(compute_element_info, (list, dict)) or compute_element_info is None):
|
|
86
|
+
err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list, dict or null"
|
|
64
87
|
logger.error_log_with_exp(err_msg,
|
|
65
88
|
ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
|
|
66
89
|
kwargs_compute_element_dict = {key_str: ComputeElement(compute_element_info=compute_element_info)
|
|
67
90
|
for key_str, compute_element_info in kwargs_dict.items()}
|
|
68
91
|
return kwargs_compute_element_dict
|
|
69
|
-
|