mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
# Ascend模型梯度状态监测工具
|
|
2
|
+
|
|
3
|
+
梯度状态监测工具提供了两种能力:
|
|
4
|
+
|
|
5
|
+
- 将模型权重的梯度数据导出。这种功能可以将模型权重的梯度值以统计量的形式采集出来,用以分析问题。
|
|
6
|
+
- 将两份梯度数据进行相似度对比。在有标杆问题中,可以确认训练过程中精度问题出现的step,以及抓取反向过程中的问题。
|
|
7
|
+
|
|
8
|
+
工具支持PyTorch版本:2.0/2.1/2.2;支持MindSpore版本:r2.3。
|
|
9
|
+
|
|
10
|
+
## 工具特性
|
|
11
|
+
|
|
12
|
+
- 使用便捷,只需在训练脚本中插入少量代码,无需改动训练流程
|
|
13
|
+
- 可以精准定位问题出现的step
|
|
14
|
+
|
|
15
|
+
## 使用方式
|
|
16
|
+
|
|
17
|
+
### 梯度数据导出
|
|
18
|
+
|
|
19
|
+
1. 创建配置文件config.json,样例如下:
|
|
20
|
+
|
|
21
|
+
```json
|
|
22
|
+
{
|
|
23
|
+
"task": "grad_probe",
|
|
24
|
+
"dump_path": "./dump_path",
|
|
25
|
+
"rank": [],
|
|
26
|
+
"step": [],
|
|
27
|
+
"grad_probe": {
|
|
28
|
+
"grad_level": "L1",
|
|
29
|
+
"param_list": [],
|
|
30
|
+
"bounds": [-1, 0, 1]
|
|
31
|
+
}
|
|
32
|
+
}
|
|
33
|
+
```
|
|
34
|
+
> step指的是优化器被调用的次数(并非模型跑的step,某些step,例如loss为nan时,不会调用优化器)
|
|
35
|
+
|
|
36
|
+
**参数说明**
|
|
37
|
+
|
|
38
|
+
| 参数 | 说明 | 输入类型 | 是否必选 |
|
|
39
|
+
|--------------------------------|-----------------------------------|-----------------|----------|
|
|
40
|
+
| task | 填为"grad_probe"。 | str | 是 |
|
|
41
|
+
| grad_level | 输出级别。决定导出数据的详细程度,级别越大导出数据越详细。可取值:L0, L1, L2|str | 是 |
|
|
42
|
+
| param_list | 权重名称列表,表示需要监控的权重。列表为空就表示监控所有权重。 | List[str] | 是 |
|
|
43
|
+
| rank | rank id列表,在多卡场景下,表示需要导出梯度数据的进程的rank id。列表为空就表示导出所有rank的数据。(MindSpore静态图模式下,当前暂不支持指定rank功能) | List[int] | 是 |
|
|
44
|
+
| step | step列表,表示需要导出数据的step列表。列表为空就表示导出所有step的数据。(MindSpore静态图模式下,当前暂不支持指定step功能) | List[int] | 是 |
|
|
45
|
+
| bounds | 区间列表,用来划分区间以统计数值的分布。需要保证由数据小到大排列。可以使用默认值[-1, 0, 1] | List[float] | 是 |
|
|
46
|
+
| dump_path | 输出目录。如果不存在就会创建一个新目录。 | str | 是 |
|
|
47
|
+
|
|
48
|
+
**不同级别的level的导出数据**
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
| 级别 | 特征数据表头 | 是否有方向数据 |
|
|
52
|
+
| ---- | ------------------------------------------------------------ | -------------- |
|
|
53
|
+
| L0 | ("param_name", "MD5", "max", "min", "norm", "shape") | 否 |
|
|
54
|
+
| L1 | ("param_name", "max", "min", "norm", "shape") | 是 |
|
|
55
|
+
| L2 | ("param_name", *intervals, "=0", "max", "min", "norm", "shape") | 是 |
|
|
56
|
+
|
|
57
|
+
intervals就是根据值分布bounds划分出的区间。
|
|
58
|
+
MindSpore静态图模式下,L0级别中暂不支持"MD5"
|
|
59
|
+
|
|
60
|
+
**方向数据解释**
|
|
61
|
+
|
|
62
|
+
因为模型的参数往往非常大,所以存储真实数据是不可接受的,这里折衷一下,只存储梯度数据的正负号(一个布尔值),也就是方向。
|
|
63
|
+
|
|
64
|
+
**bounds和值分布解释**
|
|
65
|
+
|
|
66
|
+
+ 值分布:梯度数据落在各个区间的元素个数占总元素个数的比例。
|
|
67
|
+
+ bounds:一个列表,用来划分出区间以统计值分布。例如传入bounds = [-10, 0, 10],此时有一个 grad_value: Tensor = [9.3 , 5.4, -1.0, -12.3],依据 bounds 划分出 (-inf, -10]、(-10, 0]、(0, 10]、(10, inf) 四个区间,然后统计grad_value里的数据落在每个区间内的个数,得到 1、1、2、0。如下图所示:
|
|
68
|
+

|
|
69
|
+
|
|
70
|
+
2. 插入代码。示例代码如下:
|
|
71
|
+
|
|
72
|
+
- PyTorch框架:模型构造完成后,传入config.json的路径实例化一个PrecisionDebugger对象,然后调用debugger.monitor并将`模型`作为参数传入。
|
|
73
|
+
```python
|
|
74
|
+
from msprobe.pytorch import PrecisionDebugger
|
|
75
|
+
debugger = PrecisionDebugger("config_json_path")
|
|
76
|
+
debugger.monitor(model)
|
|
77
|
+
```
|
|
78
|
+
- MindSpore框架:优化器构造完成后,传入config.json的路径实例化一个PrecisionDebugger对象,然后调用debugger.monitor并将`优化器`作为参数传入。
|
|
79
|
+
```python
|
|
80
|
+
from msprobe.mindspore import PrecisionDebugger
|
|
81
|
+
debugger = PrecisionDebugger("config_json_path")
|
|
82
|
+
debugger.monitor(optimizer)
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
3. 结束监控(MindSpore静态图模式下需要)
|
|
86
|
+
|
|
87
|
+
在训练结束之后,调用stop接口
|
|
88
|
+
|
|
89
|
+
```python
|
|
90
|
+
debugger.stop()
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
### 输出结果
|
|
94
|
+
**输出目录结构**(以level配置L2为例)
|
|
95
|
+
|
|
96
|
+
```bash
|
|
97
|
+
{dump_path}
|
|
98
|
+
├── rank{rank_id}
|
|
99
|
+
│ ├── grad_summary_{step}.csv
|
|
100
|
+
│ ├── step{step}
|
|
101
|
+
│ │ ├── {param_name}.npy
|
|
102
|
+
```
|
|
103
|
+
+ {timestamp}:梯度工具导出数据的时候会在output_path下生成一个时间戳目录,然后在这个时间戳目录下输出结果。
|
|
104
|
+
+ rank_{rank_id}:在分布式场景下,会记录卡的rank_id。非分布式场景下,如果是CPU则记录进程号,如果是NPU或GPU则记录卡号。
|
|
105
|
+
+ grad_summary_{step}.csv:会分step记录每一步的梯度数据统计值。
|
|
106
|
+
+ step_{step}:这个目录下会存放该step的梯度的方向数据。
|
|
107
|
+
+ {param_name}.pt(npy):模型参数的梯度方向数据,PyTorch保存的是pt文件,MindSpore是npy文件。
|
|
108
|
+
|
|
109
|
+
**grad_summary_{step}.csv**
|
|
110
|
+
|
|
111
|
+
样例如下:
|
|
112
|
+
|
|
113
|
+

|
|
114
|
+
|
|
115
|
+
| 字段 | 含义 |
|
|
116
|
+
| --------------------- | ------------------------------------------------------------|
|
|
117
|
+
| Param_name | 模型参数名称。 |
|
|
118
|
+
| MD5 | 梯度数据的MD5值。 |
|
|
119
|
+
| (-inf, -0.01]...(0.01, inf) | 梯度值落在区间内的元素个数占总元素的比例。 |
|
|
120
|
+
| =0 | 梯度为0的元素个数占总元素的比例。 |
|
|
121
|
+
| Max | 最大值。 |
|
|
122
|
+
| Min | 最小值。 |
|
|
123
|
+
| Norm | L2norm值。 |
|
|
124
|
+
| Shape | 形状。 |
|
|
125
|
+
|
|
126
|
+
### 梯度相似度比对
|
|
127
|
+
|
|
128
|
+
会根据所导出的权重,分step比对梯度相似度,输出每个权重的梯度相似度和总的梯度相似度。单个权重的梯度相似度为两份方向数据的重合度,总的梯度相似度为每个权重的梯度相似度按元素个数加权。
|
|
129
|
+
|
|
130
|
+
#### 前提条件
|
|
131
|
+
|
|
132
|
+
- 相同配置下,以Level为L1或L2分别采集npu和gpu环境下的梯度数据。
|
|
133
|
+
- 将两份梯度数据传到同一环境下。
|
|
134
|
+
|
|
135
|
+
#### 使用方式
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
新建如下Python脚本,传入npu和gpu的dump_path以及输出目录,比对结果输出目录不存在的话会新建:
|
|
139
|
+
|
|
140
|
+
```python
|
|
141
|
+
from msprobe import *
|
|
142
|
+
GradComparator.compare_distributed("配置文件里写的dump_path",
|
|
143
|
+
"配置文件里写的dump_path",
|
|
144
|
+
"比对结果输出目录")
|
|
145
|
+
```
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
### 比对结果
|
|
149
|
+
|
|
150
|
+
**输出目录结构**
|
|
151
|
+
|
|
152
|
+
如下为多卡比对结果,单卡则没有rank_{rank_id}这一级目录。
|
|
153
|
+
|
|
154
|
+
```bash
|
|
155
|
+
比对结果输出目录
|
|
156
|
+
├── rank{rank_id}
|
|
157
|
+
│ ├── similarities.csv
|
|
158
|
+
│ └── similarities_picture
|
|
159
|
+
│ ├── {param_name}.png
|
|
160
|
+
│ └── summary_similarities.png
|
|
161
|
+
```
|
|
162
|
+
|
|
163
|
+
**问题界定**
|
|
164
|
+
|
|
165
|
+
原则:对于任意权重,第0步的梯度相似度低于0.97,或者某一步的梯度相似度下降超过0.03,认为这一步存在精度问题。例子如下:
|
|
166
|
+
|
|
167
|
+
- 第0步相似度低于0.97
|
|
168
|
+
|
|
169
|
+

|
|
170
|
+
|
|
171
|
+
- 第3步相似度下降超过0.03
|
|
172
|
+
|
|
173
|
+

|
|
174
|
+
|
|
175
|
+
- 正常情况
|
|
176
|
+
|
|
177
|
+

|
|
178
|
+
|
|
179
|
+
这个原则是一个经验性的指标,并不是严格的标准,还需要结合实际情况具体分析。
|
|
180
|
+
|
|
181
|
+
## 公开接口
|
|
182
|
+
|
|
183
|
+
**接口说明**
|
|
184
|
+
|
|
185
|
+
```python
|
|
186
|
+
PrecisionDebugger.monitor(module)
|
|
187
|
+
```
|
|
188
|
+
|
|
189
|
+
| 参数 | 说明 | 是否必选 |
|
|
190
|
+
| ----- | -------------------- | -------- |
|
|
191
|
+
| module |PyTorch框架下传入模型,必须是torch.nn.Module;MindSpore框架下传入优化器。 | 是 |
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
**接口说明**
|
|
195
|
+
|
|
196
|
+
```python
|
|
197
|
+
GradComparator.compare_distributed(dump_path1, dump_path2, output_path)
|
|
198
|
+
```
|
|
199
|
+
|
|
200
|
+
| 参数 | 说明 | 是否必选 |
|
|
201
|
+
| ----- | -------------------- | -------- |
|
|
202
|
+
| dump_path1 |需要比对的其中一个dump目录,也就是配置文件里写的dump_path。 | 是 |
|
|
203
|
+
| dump_path2 |需要比对的其中一个dump目录,也就是配置文件里写的dump_path,与dump_path1可以互换。 | 是 |
|
|
204
|
+
| output_path |输出结果目录,不存在会新建。 | 是 |
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
# FAQ
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
File without changes
|
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from msprobe.core.common.file_check import FileOpen
|
|
5
|
+
from msprobe.core.common.utils import write_csv, add_time_as_suffix
|
|
6
|
+
from msprobe.core.common.const import Const, CompareConst, MsCompareConst
|
|
7
|
+
from msprobe.core.common.log import logger
|
|
8
|
+
from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo
|
|
9
|
+
from msprobe.mindspore.api_accuracy_checker.api_runner import api_runner, ApiInputAggregation
|
|
10
|
+
from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
|
|
11
|
+
from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class BasicInfoAndStatus:
    """Describe one compared output slot of an API together with its comparison verdict."""

    def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
        # Name of the compared output slot.
        self.api_name = api_name
        # dtypes of the benchmark-side and tested-side outputs.
        self.bench_dtype, self.tested_dtype = bench_dtype, tested_dtype
        self.shape = shape
        # Verdict of the comparison and the accompanying message.
        self.status, self.err_msg = status, err_msg
|
|
22
|
+
|
|
23
|
+
class ResultCsvEntry:
    """Hold the forward/backward verdicts and messages of a single API for the result CSV."""

    def __init__(self) -> None:
        # Per-direction pass status; None means it has not been set.
        self.forward_pass_status = self.backward_pass_status = None
        # Per-direction error messages start out empty.
        self.forward_err_msg = self.backward_err_msg = ""
        self.overall_err_msg = None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ApiAccuracyChecker:
|
|
33
|
+
def __init__(self):
    # api_infos: API name (direction suffix stripped) -> ApiInfo, populated by parse().
    # results: (api name, direction) -> list of (BasicInfoAndStatus, compare-result dict),
    # populated by record().
    self.api_infos = {}
    self.results = {}
|
|
36
|
+
|
|
37
|
+
@staticmethod
def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
    '''
    Run one direction of an API on both frameworks and compare every output slot.

    Args:
        api_info: ApiInfo
        api_name_str: str
        api_input_aggregation: ApiInputAggregation
        forward_or_backward: str: Union["forward", "backward"]

    Return:
        output_list: List[tuple(str, str, BasicInfoAndStatus, dict{str: CompareResult})]

    Description:
        Tested outputs come from re-running the MindSpore API when the global context is
        "constructed"; otherwise they are taken from api_info. Benchmark outputs always come
        from running the PyTorch side. Every registered compare algorithm scores each output
        pair. A pair passes only when both the cosine and the max-abs-error results pass.
    '''
    if global_context.get_is_constructed():
        # Inputs were constructed, so the tested outputs must be produced by running MindSpore.
        tested_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.MS_FRAMEWORK)
    else:
        tested_outputs = api_info.get_compute_element_list(forward_or_backward, Const.OUTPUT)
    bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)

    collected = []
    for slot, (bench_out, tested_out) in enumerate(zip(bench_outputs, tested_outputs)):
        slot_name = Const.SEP.join([api_name_str, forward_or_backward, Const.OUTPUT, str(slot)])
        bench_dtype = bench_out.get_dtype()
        tested_dtype = tested_out.get_dtype()
        shape = bench_out.get_shape()

        scores = {name: algorithm(bench_out, tested_out) for name, algorithm in compare_algorithms.items()}
        cosine_result = scores.get(CompareConst.COSINE)
        max_abs_result = scores.get(CompareConst.MAX_ABS_ERR)

        both_pass = (cosine_result.pass_status == CompareConst.PASS
                     and max_abs_result.pass_status == CompareConst.PASS)
        if both_pass:
            status, err_msg = CompareConst.PASS, ""
        else:
            status, err_msg = CompareConst.ERROR, cosine_result.err_msg + max_abs_result.err_msg

        slot_info = BasicInfoAndStatus(slot_name, bench_dtype, tested_dtype, shape, status, err_msg)
        collected.append((api_name_str, forward_or_backward, slot_info, scores))
    return collected
|
|
87
|
+
|
|
88
|
+
def parse(self, api_info_path):
    """
    Load an api_info.json file, initialise the global context and populate self.api_infos.

    The "task" field selects the mode. For the statistics task, inputs are constructed and
    no dump_data_dir is read. Otherwise "dump_data_dir" is required.

    From the "data" field, only entries whose first name segment is MsCompareConst.MINT or
    MsCompareConst.MINT_FUNCTIONAL are kept. Each kept entry's last name segment must be
    forward or backward. Entries with any other suffix are logged and skipped.

    Args:
        api_info_path: path of the api_info.json file to read.
    """
    with FileOpen(api_info_path, "r") as f:
        api_info_dict = json.load(f)

    # init global context
    task = check_and_get_from_json_dict(api_info_dict, MsCompareConst.TASK_FIELD,
                                        "task field in api_info.json", accepted_type=str,
                                        accepted_value=(MsCompareConst.STATISTICS_TASK,
                                                        MsCompareConst.TENSOR_TASK))
    is_constructed = task == MsCompareConst.STATISTICS_TASK
    if not is_constructed:
        dump_data_dir = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DUMP_DATA_DIR_FIELD,
                                                     "dump_data_dir field in api_info.json", accepted_type=str)
    else:
        dump_data_dir = ""
    global_context.init(is_constructed, dump_data_dir)

    api_info_data = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DATA_FIELD,
                                                 "data field in api_info.json", accepted_type=dict)
    for api_name, api_info in api_info_data.items():
        name_parts = api_name.split(Const.SEP)
        if name_parts[0] not in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL):
            continue
        forbackward_str = name_parts[-1]
        if forbackward_str not in (Const.FORWARD, Const.BACKWARD):
            logger.warning(f"api: {api_name} is not recognized as forward api or backward api, skip this.")
            # Honour the "skip this" promise. Without the continue, the entry would fall
            # through and be loaded as backward info under a truncated name.
            continue
        api_name = Const.SEP.join(name_parts[:-1])  # www.xxx.yyy.zzz --> www.xxx.yyy
        if api_name not in self.api_infos:
            self.api_infos[api_name] = ApiInfo(api_name)

        if forbackward_str == Const.FORWARD:
            self.api_infos[api_name].load_forward_info(api_info)
        else:
            self.api_infos[api_name].load_backward_info(api_info)
|
|
123
|
+
|
|
124
|
+
def run_and_compare(self):
    """
    Run and compare the forward pass, then the backward pass, of every parsed API.

    APIs without forward info are skipped entirely. APIs without backward info get only the
    forward check. An exception raised while running or comparing one direction is logged.
    In that case that direction records nothing (record(None) is a no-op), and processing
    continues.
    """
    for api_name_str, api_info in self.api_infos.items():
        if not api_info.check_forward_info():
            logger.warning(f"api: {api_name_str} lacks forward information, skip forward and backward check")
            continue
        forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
        kwargs = api_info.get_kwargs()
        forward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, None)
        forward_output_list = None
        try:
            forward_output_list = \
                self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
        except Exception as e:
            # Best-effort: one failing API must not abort the whole run.
            logger.warning(f"exception occurs when running and comparing {api_name_str} forward api, "
                           f"detailed exception information: {e}")
        self.record(forward_output_list)

        if not api_info.check_backward_info():
            logger.warning(f"api: {api_name_str} lacks backward information, skip backward check")
            continue
        gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
        # Backward reuses the forward inputs/kwargs and adds the incoming gradients.
        backward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
        backward_output_list = None
        try:
            backward_output_list = \
                self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation, Const.BACKWARD)
        except Exception as e:
            logger.warning(f"exception occurs when running and comparing {api_name_str} backward api, "
                           f"detailed exception information: {e}")
        self.record(backward_output_list)
|
|
154
|
+
|
|
155
|
+
def record(self, output_list):
    """
    Group comparison outputs into self.results, keyed by (api name, direction).

    Args:
        output_list: iterable of (api_real_name, forward_or_backward, basic_info,
            compare_result_dict) 4-tuples, or None. None is ignored; run_and_compare
            passes None when a direction raised.
    """
    if output_list is None:
        return
    for api_real_name, forward_or_backward, basic_info, compare_result_dict in output_list:
        bucket = self.results.setdefault((api_real_name, forward_or_backward), [])
        bucket.append((basic_info, compare_result_dict))
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def to_detail_csv(self, csv_dir):
    """
    Write the detail CSV with one row per recorded output slot into csv_dir.

    Columns are: api name, bench dtype, tested dtype and shape. These are followed by one
    column per registered compare algorithm (its compare_value), then pass status and
    message. The file name is MsCompareConst.DETAIL_CSV_FILE_NAME with a time suffix
    appended by add_time_as_suffix.
    """
    algorithm_names = list(compare_algorithms.keys())
    header = [
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
        MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
        MsCompareConst.DETAIL_CSV_SHAPE,
        *algorithm_names,
        MsCompareConst.DETAIL_CSV_PASS_STATUS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    ]
    rows = [header]
    for per_key_results in self.results.values():
        for basic_info, compare_result_dict in per_key_results:
            rows.append([
                basic_info.api_name,
                basic_info.bench_dtype,
                basic_info.tested_dtype,
                basic_info.shape,
                *(compare_result_dict.get(name).compare_value for name in algorithm_names),
                basic_info.status,
                basic_info.err_msg,
            ])

    file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.DETAIL_CSV_FILE_NAME))
    write_csv(rows, file_name, mode="w")
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def to_result_csv(self, csv_dir):
    """
    Summarise the recorded results into one pass/fail row per API and write them as a CSV
    file in ``csv_dir``.

    For each (api_real_name, forward_or_backward) key in ``self.results``, that direction is
    marked ``CompareConst.ERROR`` when any entry's status differs from ``CompareConst.PASS``,
    and ``CompareConst.PASS`` otherwise. The direction's message is the concatenation (no
    separator) of the failing entries' ``err_msg`` values, and is empty when nothing failed.

    Each output row is: api name, forward status, backward status, message. The message is
    the forward message followed by the backward message, or empty when both directions
    pass. The first row is the header.

    Args:
        csv_dir: str, directory in which the result CSV file is created. The file name is
            MsCompareConst.RESULT_CSV_FILE_NAME passed through add_time_as_suffix.
    """
    entries_by_api = {}
    for (api_real_name, forward_or_backward), results in self.results.items():
        # Only failing entries contribute a message; their presence alone decides the status.
        failed_msgs = [
            basic_info.err_msg
            for basic_info, _ in results
            if basic_info.status != CompareConst.PASS
        ]
        direction_status = CompareConst.ERROR if failed_msgs else CompareConst.PASS
        direction_err_msg = "".join(failed_msgs)

        if api_real_name in entries_by_api:
            entry = entries_by_api[api_real_name]
        else:
            entry = ResultCsvEntry()
            entries_by_api[api_real_name] = entry

        if forward_or_backward == Const.FORWARD:
            entry.forward_pass_status = direction_status
            entry.forward_err_msg = direction_err_msg
        else:
            entry.backward_pass_status = direction_status
            entry.backward_err_msg = direction_err_msg

    result_csv = [[
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
        MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    ]]
    for api_name, entry in entries_by_api.items():
        both_passed = (entry.forward_pass_status == CompareConst.PASS
                       and entry.backward_pass_status == CompareConst.PASS)
        message = "" if both_passed else entry.forward_err_msg + entry.backward_err_msg
        result_csv.append([api_name, entry.forward_pass_status, entry.backward_pass_status, message])

    file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
    write_csv(result_csv, file_name, mode="w")
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
2
|
+
from msprobe.core.common.const import Const
|
|
3
|
+
from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
|
|
4
|
+
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
5
|
+
from msprobe.core.common.log import logger
|
|
6
|
+
|
|
7
|
+
class ApiInfo:
    '''
    Container for one API's forward and backward records taken from api_info.json.

    It turns those records into ComputeElement objects via get_compute_element_list and get_kwargs.
    '''

    def __init__(self, api_name):
        self.api_name = api_name
        # Forward/backward record dicts; None until load_forward_info / load_backward_info is called.
        self.forward_info = None
        self.backward_info = None

    def load_forward_info(self, forward_info_dict):
        '''Store the forward record dict for this API.'''
        self.forward_info = forward_info_dict

    def load_backward_info(self, backward_info_dict):
        '''Store the backward record dict for this API.'''
        self.backward_info = backward_info_dict

    def check_forward_info(self):
        '''Return True when a forward record has been loaded.'''
        return self.forward_info is not None

    def check_backward_info(self):
        '''Return True when a backward record has been loaded.'''
        return self.backward_info is not None

    def get_compute_element_list(self, forward_or_backward, input_or_output):
        '''
        Build ComputeElement objects from one list field of the forward or backward record.

        Field used per combination:
            (forward, input)   -> input_args of the forward record
            (forward, output)  -> output of the forward record
            (backward, input)  -> input of the backward record
            (backward, output) -> output of the backward record

        Args:
            forward_or_backward: str, Union["forward", "backward"]
            input_or_output: str, Union["input", "output"]

        Return:
            compute_element_list: List[ComputeElement]

        Raises:
            ValueError: if (forward_or_backward, input_or_output) is not one of the four
                combinations above.
        '''
        mapping = {
            (Const.FORWARD, Const.INPUT): (self.forward_info, Const.INPUT_ARGS,
                                           f"input_args field of {self.api_name} forward api in api_info.json"),
            (Const.FORWARD, Const.OUTPUT): (self.forward_info, Const.OUTPUT,
                                            f"output field of {self.api_name} forward api in api_info.json"),
            (Const.BACKWARD, Const.INPUT): (self.backward_info, Const.INPUT,
                                            f"input field of {self.api_name} backward api in api_info.json"),
            (Const.BACKWARD, Const.OUTPUT): (self.backward_info, Const.OUTPUT,
                                             f"output field of {self.api_name} backward api in api_info.json"),
        }
        selected = mapping.get((forward_or_backward, input_or_output))
        if selected is None:
            # Without this check the tuple unpacking below fails with an opaque
            # "cannot unpack non-iterable NoneType object" TypeError.
            raise ValueError(
                f"ApiInfo.get_compute_element_list failed: unsupported combination "
                f"forward_or_backward={forward_or_backward!r}, input_or_output={input_or_output!r}"
            )
        dict_instance, key, key_desc = selected
        compute_element_info_list = check_and_get_from_json_dict(dict_instance, key, key_desc, accepted_type=list)
        compute_element_list = [ComputeElement(compute_element_info=compute_element_info)
                                for compute_element_info in compute_element_info_list]
        return compute_element_list

    def get_kwargs(self):
        '''
        Build ComputeElement objects from the input_kwargs field of the forward record.

        All entries are validated before any ComputeElement is constructed. Each key must be
        a str and each value a list or dict. Violations are reported through
        logger.error_log_with_exp with an ApiAccuracyCheckerException(ParseJsonFailed).

        Return:
            kwargs_compute_element_dict: dict{str: ComputeElement}
        '''
        kwargs_dict = check_and_get_from_json_dict(self.forward_info, Const.INPUT_KWARGS,
                                                   "input_kwargs in api_info.json", accepted_type=dict)
        for key_str, compute_element_info in kwargs_dict.items():
            if not isinstance(key_str, str):
                err_msg = "ApiInfo.get_kwargs failed: compute_element_dict key is not a string"
                logger.error_log_with_exp(err_msg,
                                          ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
            if not isinstance(compute_element_info, (list, dict)):
                err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list or dict"
                logger.error_log_with_exp(err_msg,
                                          ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
        kwargs_compute_element_dict = {key_str: ComputeElement(compute_element_info=compute_element_info)
                                       for key_str, compute_element_info in kwargs_dict.items()}
        return kwargs_compute_element_dict
|
|
69
|
+
|