mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +131 -237
- msprobe/__init__.py +16 -1
- msprobe/{config/config.json → config.json} +47 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +58 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +402 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +523 -283
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +86 -69
- msprobe/core/common/utils.py +371 -616
- msprobe/core/common_config.py +78 -71
- msprobe/core/compare/acc_compare.py +472 -298
- msprobe/core/compare/check.py +180 -95
- msprobe/core/compare/compare_cli.py +69 -49
- msprobe/core/compare/highlight.py +259 -222
- msprobe/core/compare/multiprocessing_compute.py +174 -149
- msprobe/core/compare/npy_compare.py +310 -295
- msprobe/core/compare/utils.py +464 -429
- msprobe/core/data_dump/data_collector.py +153 -144
- msprobe/core/data_dump/data_processor/base.py +337 -293
- msprobe/core/data_dump/data_processor/factory.py +76 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
- msprobe/core/data_dump/json_writer.py +117 -116
- msprobe/core/data_dump/scope.py +194 -178
- msprobe/core/grad_probe/constant.py +74 -70
- msprobe/core/grad_probe/grad_compare.py +170 -175
- msprobe/core/grad_probe/utils.py +77 -52
- msprobe/docs/01.installation.md +99 -0
- msprobe/docs/02.config_introduction.md +137 -0
- msprobe/docs/03.config_examples.md +237 -0
- msprobe/docs/04.acl_config_examples.md +78 -0
- msprobe/docs/05.data_dump_PyTorch.md +326 -0
- msprobe/docs/06.data_dump_MindSpore.md +285 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
- msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
- msprobe/docs/FAQ.md +189 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +2 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +58 -34
- msprobe/mindspore/common/const.py +108 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +97 -57
- msprobe/mindspore/compare/distributed_compare.py +62 -75
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +357 -117
- msprobe/mindspore/compare/ms_graph_compare.py +364 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +69 -74
- msprobe/mindspore/debugger/precision_debugger.py +150 -107
- msprobe/mindspore/dump/dump_tool_factory.py +50 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
- msprobe/mindspore/dump/jit_dump.py +96 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
- msprobe/mindspore/free_benchmark/common/config.py +27 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
- msprobe/mindspore/free_benchmark/common/utils.py +85 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
- msprobe/mindspore/grad_probe/global_context.py +100 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +297 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +23 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +30 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
- msprobe/pytorch/bench_functions/linear.py +27 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
- msprobe/pytorch/bench_functions/rms_norm.py +30 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
- msprobe/pytorch/bench_functions/swiglu.py +70 -55
- msprobe/pytorch/common/__init__.py +17 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +33 -32
- msprobe/pytorch/common/parse_json.py +54 -39
- msprobe/pytorch/common/utils.py +310 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +49 -33
- msprobe/pytorch/compare/pt_compare.py +82 -40
- msprobe/pytorch/debugger/debugger_config.py +108 -95
- msprobe/pytorch/debugger/precision_debugger.py +173 -125
- msprobe/pytorch/free_benchmark/__init__.py +23 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +65 -37
- msprobe/pytorch/free_benchmark/common/params.py +144 -129
- msprobe/pytorch/free_benchmark/common/utils.py +118 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
- msprobe/pytorch/free_benchmark/main.py +120 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
- msprobe/pytorch/function_factory.py +91 -75
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +166 -161
- msprobe/pytorch/hook_module/hook_module.py +118 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +28 -29
- msprobe/pytorch/hook_module/wrap_aten.py +111 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
- msprobe/pytorch/hook_module/wrap_functional.py +104 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
- msprobe/pytorch/hook_module/wrap_torch.py +84 -86
- msprobe/pytorch/hook_module/wrap_vf.py +60 -62
- msprobe/pytorch/module_processer.py +153 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +235 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
- msprobe/pytorch/online_dispatch/utils.py +127 -146
- msprobe/pytorch/parse.py +19 -4
- msprobe/pytorch/parse_tool/cli.py +31 -32
- msprobe/pytorch/parse_tool/lib/compare.py +259 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
- msprobe/pytorch/parse_tool/lib/utils.py +320 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +317 -187
- msprobe/pytorch/service.py +311 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
|
@@ -1,69 +1,76 @@
|
|
|
1
|
-
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
2
|
-
from msprobe.core.common.const import Const
|
|
3
|
-
from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
|
|
4
|
-
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
5
|
-
from msprobe.
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
self.
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
'''
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
1
|
+
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
2
|
+
from msprobe.core.common.const import Const
|
|
3
|
+
from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
|
|
4
|
+
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
5
|
+
from msprobe.mindspore.common.log import logger
|
|
6
|
+
from msprobe.core.common.utils import is_invalid_pattern
|
|
7
|
+
|
|
8
|
+
class ApiInfo:
    '''
    Forward/backward records of one API loaded from api_info.json.

    The raw records are stored unchanged by load_forward_info / load_backward_info
    and turned into ComputeElement objects on demand.
    '''

    def __init__(self, api_name):
        '''
        Args:
            api_name: str, name of the API record

        Description:
            a non-str api_name, or one rejected by is_invalid_pattern, is reported through
            logger.error_log_with_exp with ApiAccuracyCheckerException.ParseJsonFailed
        '''
        if not isinstance(api_name, str):
            err_msg = "ApiInfo.__init__ failed: api_name is not a string"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
        if is_invalid_pattern(api_name):
            err_msg = "ApiInfo.__init__ failed: api_name contain illegal character"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
        self.api_name = api_name
        self.forward_info = None  # forward record, set by load_forward_info
        self.backward_info = None  # backward record, set by load_backward_info

    def load_forward_info(self, forward_info_dict):
        '''Store the forward record of this API as-is.'''
        self.forward_info = forward_info_dict

    def load_backward_info(self, backward_info_dict):
        '''Store the backward record of this API as-is.'''
        self.backward_info = backward_info_dict

    def check_forward_info(self):
        '''Return True once a forward record has been loaded.'''
        return self.forward_info is not None

    def check_backward_info(self):
        '''Return True once a backward record has been loaded.'''
        return self.backward_info is not None

    def get_compute_element_list(self, forward_or_backward, input_or_output):
        '''
        Args:
            forward_or_backward: str, Union["forward", "backward"]
            input_or_output: str, Union["input", "output"]

        Return:
            compute_element_list: List[ComputeElement]

        Description:
            forward input is read from the input_args field; forward output and both
            backward fields are read from the output / input fields of the record.
            An unsupported (forward_or_backward, input_or_output) pair is reported
            through logger.error_log_with_exp with ApiAccuracyCheckerException.WrongValue.
        '''
        mapping = {
            (Const.FORWARD, Const.INPUT): [self.forward_info, Const.INPUT_ARGS,
                                           f"input_args field of {self.api_name} forward api in api_info.json"],
            (Const.FORWARD, Const.OUTPUT): [self.forward_info, Const.OUTPUT,
                                            f"output field of {self.api_name} forward api in api_info.json"],
            (Const.BACKWARD, Const.INPUT): [self.backward_info, Const.INPUT,
                                            f"input field of {self.api_name} backward api in api_info.json"],
            (Const.BACKWARD, Const.OUTPUT): [self.backward_info, Const.OUTPUT,
                                             f"output field of {self.api_name} backward api in api_info.json"]
        }
        entry = mapping.get((forward_or_backward, input_or_output))
        if entry is None:
            # Without this check an unknown pair would fail with an opaque
            # "cannot unpack NoneType" TypeError on the unpacking below.
            err_msg = f"ApiInfo.get_compute_element_list failed: unsupported combination " \
                      f"forward_or_backward={forward_or_backward}, input_or_output={input_or_output}"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
        dict_instance, key, key_desc = entry
        compute_element_info_list = check_and_get_from_json_dict(dict_instance, key, key_desc, accepted_type=list)
        compute_element_list = [ComputeElement(compute_element_info=compute_element_info)
                                for compute_element_info in compute_element_info_list]
        return compute_element_list

    def get_kwargs(self):
        '''
        Return:
            kwargs_compute_element_dict: dict{str: ComputeElement}

        Description:
            every entry of the forward record's input_kwargs is validated (str key,
            list/dict value) before any ComputeElement is built.
        '''
        kwargs_dict = check_and_get_from_json_dict(self.forward_info, Const.INPUT_KWARGS,
                                                   "input_kwargs in api_info.json", accepted_type=dict)
        for key_str, compute_element_info in kwargs_dict.items():
            if not isinstance(key_str, str):
                err_msg = "ApiInfo.get_kwargs failed: compute_element_dict key is not a string"
                logger.error_log_with_exp(err_msg,
                                          ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
            if not isinstance(compute_element_info, (list, dict)):
                err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list or dict"
                logger.error_log_with_exp(err_msg,
                                          ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
        kwargs_compute_element_dict = {key_str: ComputeElement(compute_element_info=compute_element_info)
                                       for key_str, compute_element_info in kwargs_dict.items()}
        return kwargs_compute_element_dict
|
|
76
|
+
|
|
@@ -1,152 +1,156 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
import mindspore
|
|
4
|
-
import torch
|
|
5
|
-
from mindspore import ops
|
|
6
|
-
|
|
7
|
-
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
8
|
-
from msprobe.core.common.const import Const, MsCompareConst
|
|
9
|
-
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
10
|
-
from msprobe.
|
|
11
|
-
from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
self.
|
|
24
|
-
self.
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
(MsCompareConst.MINT, Const.
|
|
29
|
-
(MsCompareConst.
|
|
30
|
-
(MsCompareConst.MINT_FUNCTIONAL, Const.
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
mindspore.mint
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
if
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
import mindspore
|
|
4
|
+
import torch
|
|
5
|
+
from mindspore import ops
|
|
6
|
+
|
|
7
|
+
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
8
|
+
from msprobe.core.common.const import Const, MsCompareConst
|
|
9
|
+
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
10
|
+
from msprobe.mindspore.common.log import logger
|
|
11
|
+
from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
|
|
12
|
+
from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ApiInputAggregation:
    '''
    Bundle of everything needed to invoke an API once: its positional inputs,
    its keyword inputs and, for a backward run, the gradients to propagate.
    '''

    def __init__(self, inputs, kwargs, gradient_inputs) -> None:
        '''
        Args:
            inputs: List[ComputeElement], positional arguments of the API
            kwargs: dict{str: ComputeElement}, keyword arguments of the API
            gradient_inputs: Union[List[ComputeElement], None], gradients consumed by a
                backward run; a forward run does not read it, so it may be None
        '''
        self.inputs, self.kwargs, self.gradient_inputs = inputs, kwargs, gradient_inputs
|
|
26
|
+
|
|
27
|
+
# (api type, framework) -> module object that owns the API functions.
# mint APIs map to mindspore.mint / torch; mint.nn.functional APIs map to
# mindspore.mint.nn.functional / torch.nn.functional.
api_parent_module_mapping = {
    # mint.<name>  <-->  torch.<name>
    (MsCompareConst.MINT, Const.MS_FRAMEWORK): mindspore.mint,
    (MsCompareConst.MINT, Const.PT_FRAMEWORK): torch,
    # mint.nn.functional.<name>  <-->  torch.nn.functional.<name>
    (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
    (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional,
}
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ApiRunner:
    '''
    Runs a mindspore.mint / mindspore.mint.nn.functional API, or its torch counterpart,
    on the inputs held by an ApiInputAggregation, in forward or backward mode.
    '''

    def __call__(self, api_input_aggregation, api_name_str, forward_or_backward=Const.FORWARD,
                 api_platform=Const.MS_FRAMEWORK):
        '''
        Args:
            api_input_aggregation: ApiInputAggregation
            api_name_str: str, e.g. "MintFunctional.relu.0"
            forward_or_backward: str, Const.FORWARD or Const.BACKWARD
                (any value other than Const.FORWARD runs the backward pass)
            api_platform: str, Const.MS_FRAMEWORK or Const.PT_FRAMEWORK

        Return:
            outputs: list[ComputeElement]

        Description:
            run mindspore.mint/torch api
        '''
        api_type_str, api_sub_name = self.get_info_from_name(api_name_str)
        api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform)

        return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform)

    @staticmethod
    def get_info_from_name(api_name_str):
        '''
        Args:
            api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"

        Return:
            api_type_str: str, Union["MintFunctional", "Mint"]
            api_sub_name: str, e.g. "relu"

        Description:
            the name must split on Const.SEP into exactly three parts, and its first part
            must be MsCompareConst.MINT or MsCompareConst.MINT_FUNCTIONAL; otherwise the
            failure is reported with ApiAccuracyCheckerException.WrongValue
        '''
        api_name_list = api_name_str.split(Const.SEP)
        if len(api_name_list) != 3:
            err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
        api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
        if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL]:
            err_msg = "ApiRunner.get_info_from_name failed: not mint or mint.nn.functional api"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))

        return api_type_str, api_sub_name

    @staticmethod
    def get_api_instance(api_type_str, api_sub_name, api_platform):
        '''
        Args:
            api_type_str: str, Union["MintFunctional", "Mint"]
            api_sub_name: str, e.g. "relu"
            api_platform: str, Const.MS_FRAMEWORK or Const.PT_FRAMEWORK

        Return:
            api_instance: function object

        Description:
            get mindspore.mint/torch api function
            mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
            mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
            a missing or non-callable attribute is reported with ApiAccuracyCheckerException.ApiWrong
        '''
        api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
        # Human-readable name used only in error messages.
        module_str = "mindspore.mint." if api_platform == Const.MS_FRAMEWORK else "torch."
        submodule_str = "nn.functional." if api_type_str == MsCompareConst.MINT_FUNCTIONAL else ""
        full_api_name = module_str + submodule_str + api_sub_name
        if not hasattr(api_parent_module, api_sub_name):
            err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))

        api_instance = getattr(api_parent_module, api_sub_name)
        if not callable(api_instance):
            err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not callable"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))

        return api_instance

    @staticmethod
    def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform):
        '''
        Args:
            api_instance: callable returned by get_api_instance
            api_input_aggregation: ApiInputAggregation
            forward_or_backward: str, Const.FORWARD runs forward; anything else runs backward
            api_platform: str, Const.MS_FRAMEWORK or Const.PT_FRAMEWORK

        Return:
            list[ComputeElement]: the forward outputs, or the input gradients for a backward run

        Description:
            a backward run without gradient_inputs is reported with
            ApiAccuracyCheckerException.WrongValue
        '''
        inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                       for compute_element in api_input_aggregation.inputs)
        kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform)
                  for key, value in api_input_aggregation.kwargs.items()}
        gradient_inputs = api_input_aggregation.gradient_inputs

        if forward_or_backward == Const.FORWARD:
            forward_result = api_instance(*inputs, **kwargs)  # can be single tensor or tuple
            return [ComputeElement(parameter=api_res) for api_res in convert_to_tuple(forward_result)]

        if gradient_inputs is None:
            err_msg = "ApiRunner.run_api failed: run backward api but gradient_inputs is missing"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
        gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                                for compute_element in gradient_inputs)
        if api_platform == Const.MS_FRAMEWORK:
            backward_results = ApiRunner._run_ms_backward(api_instance, inputs, kwargs, gradient_inputs)
        else:
            backward_results = ApiRunner._run_torch_backward(api_instance, inputs, kwargs, gradient_inputs)
        return [ComputeElement(parameter=api_res) for api_res in backward_results]

    @staticmethod
    def _run_ms_backward(api_instance, inputs, kwargs, gradient_inputs):
        '''Compute gradients w.r.t. all positional inputs with ops.GradOperation; returns a tuple.'''
        # A single-output API takes its sensitivity as a bare tensor rather than a 1-tuple.
        if len(gradient_inputs) == 1:
            gradient_inputs = gradient_inputs[0]

        def api_with_kwargs(*forward_inputs):
            return api_instance(*forward_inputs, **kwargs)

        grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs)
        backward_result = grad_func(*inputs, gradient_inputs)  # can be single tensor or tuple
        return convert_to_tuple(backward_result)

    @staticmethod
    def _run_torch_backward(api_instance, inputs, kwargs, gradient_inputs):
        '''Return the .grad of every floating-point torch tensor input after one backward pass.'''
        # Only floating-point tensor inputs are marked as requiring grad.
        requires_grad_index = []
        for index, tensor in enumerate(inputs):
            if isinstance(tensor, torch.Tensor) and \
                    torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list:
                tensor.requires_grad = True
                requires_grad_index.append(index)
        forward_results = convert_to_tuple(api_instance(*inputs, **kwargs))
        # Back-propagate all differentiable outputs in a single call. Calling .backward()
        # once per output re-traverses a shared graph, which fails after the first call
        # frees its buffers. Outputs that cannot carry gradients (e.g. integer indices)
        # are skipped instead of raising.
        differentiable = [(forward_res, gradient_in)
                          for forward_res, gradient_in in zip(forward_results, gradient_inputs)
                          if isinstance(forward_res, torch.Tensor) and forward_res.requires_grad]
        if differentiable:
            outputs, grads = zip(*differentiable)
            torch.autograd.backward(outputs, grad_tensors=grads)
        return [inputs[index].grad for index in requires_grad_index]


api_runner = ApiRunner()
|