mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +131 -237
- msprobe/__init__.py +16 -1
- msprobe/{config/config.json → config.json} +47 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +58 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +402 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +523 -283
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +86 -69
- msprobe/core/common/utils.py +371 -616
- msprobe/core/common_config.py +78 -71
- msprobe/core/compare/acc_compare.py +472 -298
- msprobe/core/compare/check.py +180 -95
- msprobe/core/compare/compare_cli.py +69 -49
- msprobe/core/compare/highlight.py +259 -222
- msprobe/core/compare/multiprocessing_compute.py +174 -149
- msprobe/core/compare/npy_compare.py +310 -295
- msprobe/core/compare/utils.py +464 -429
- msprobe/core/data_dump/data_collector.py +153 -144
- msprobe/core/data_dump/data_processor/base.py +337 -293
- msprobe/core/data_dump/data_processor/factory.py +76 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
- msprobe/core/data_dump/json_writer.py +117 -116
- msprobe/core/data_dump/scope.py +194 -178
- msprobe/core/grad_probe/constant.py +74 -70
- msprobe/core/grad_probe/grad_compare.py +170 -175
- msprobe/core/grad_probe/utils.py +77 -52
- msprobe/docs/01.installation.md +99 -0
- msprobe/docs/02.config_introduction.md +137 -0
- msprobe/docs/03.config_examples.md +237 -0
- msprobe/docs/04.acl_config_examples.md +78 -0
- msprobe/docs/05.data_dump_PyTorch.md +326 -0
- msprobe/docs/06.data_dump_MindSpore.md +285 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
- msprobe/{pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md → docs/18.online_dispatch.md} +89 -90
- msprobe/docs/FAQ.md +189 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +2 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +58 -34
- msprobe/mindspore/common/const.py +108 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +97 -57
- msprobe/mindspore/compare/distributed_compare.py +62 -75
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +357 -117
- msprobe/mindspore/compare/ms_graph_compare.py +364 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +69 -74
- msprobe/mindspore/debugger/precision_debugger.py +150 -107
- msprobe/mindspore/dump/dump_tool_factory.py +50 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
- msprobe/mindspore/dump/jit_dump.py +96 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
- msprobe/mindspore/free_benchmark/common/config.py +27 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
- msprobe/mindspore/free_benchmark/common/utils.py +85 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
- msprobe/mindspore/grad_probe/global_context.py +100 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +297 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +23 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +30 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
- msprobe/pytorch/bench_functions/linear.py +27 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
- msprobe/pytorch/bench_functions/rms_norm.py +30 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
- msprobe/pytorch/bench_functions/swiglu.py +70 -55
- msprobe/pytorch/common/__init__.py +17 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +33 -32
- msprobe/pytorch/common/parse_json.py +54 -39
- msprobe/pytorch/common/utils.py +310 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +49 -33
- msprobe/pytorch/compare/pt_compare.py +82 -40
- msprobe/pytorch/debugger/debugger_config.py +108 -95
- msprobe/pytorch/debugger/precision_debugger.py +173 -125
- msprobe/pytorch/free_benchmark/__init__.py +23 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +65 -37
- msprobe/pytorch/free_benchmark/common/params.py +144 -129
- msprobe/pytorch/free_benchmark/common/utils.py +118 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
- msprobe/pytorch/free_benchmark/main.py +120 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
- msprobe/pytorch/function_factory.py +91 -75
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +166 -161
- msprobe/pytorch/hook_module/hook_module.py +118 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +28 -29
- msprobe/pytorch/hook_module/wrap_aten.py +111 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
- msprobe/pytorch/hook_module/wrap_functional.py +104 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
- msprobe/pytorch/hook_module/wrap_torch.py +84 -86
- msprobe/pytorch/hook_module/wrap_vf.py +60 -62
- msprobe/pytorch/module_processer.py +153 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +235 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
- msprobe/pytorch/online_dispatch/utils.py +127 -146
- msprobe/pytorch/parse.py +19 -4
- msprobe/pytorch/parse_tool/cli.py +31 -32
- msprobe/pytorch/parse_tool/lib/compare.py +259 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
- msprobe/pytorch/parse_tool/lib/utils.py +320 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +317 -187
- msprobe/pytorch/service.py +311 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/pytorch/pt_config.py
CHANGED
|
@@ -1,187 +1,317 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
self.check_config()
|
|
42
|
-
self.
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
self.
|
|
69
|
-
self.
|
|
70
|
-
self.
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
self.
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
self.
|
|
89
|
-
self.
|
|
90
|
-
self.
|
|
91
|
-
self.
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
self.
|
|
95
|
-
self.
|
|
96
|
-
self.
|
|
97
|
-
self.
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
if
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
def
|
|
116
|
-
if
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
20
|
+
from msprobe.core.common.file_utils import FileOpen, load_json
|
|
21
|
+
from msprobe.core.common.log import logger
|
|
22
|
+
from msprobe.core.common_config import BaseConfig, CommonConfig
|
|
23
|
+
from msprobe.core.grad_probe.constant import level_adp
|
|
24
|
+
from msprobe.core.grad_probe.utils import check_bounds
|
|
25
|
+
from msprobe.pytorch.free_benchmark.common.enums import (
|
|
26
|
+
DeviceType,
|
|
27
|
+
HandlerType,
|
|
28
|
+
PytorchFreeBenchmarkConst,
|
|
29
|
+
)
|
|
30
|
+
from msprobe.pytorch.hook_module.utils import get_ops
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TensorConfig(BaseConfig):
    """Task-level configuration for the tensor task.

    In addition to whatever ``BaseConfig.__init__`` populates, this reads the
    ``online_run_ut``, ``nfs_path``, ``host``, ``port`` and ``tls_path`` keys
    from ``json_config`` and validates the result on construction.
    """

    def __init__(self, json_config):
        """Read the tensor-task settings from ``json_config`` and validate them.

        Raises:
            Exception: if ``file_format`` or ``tls_path`` fails validation.
        """
        super().__init__(json_config)
        read = json_config.get
        self.online_run_ut = read("online_run_ut", False)
        self.nfs_path = read("nfs_path", "")
        self.host = read("host", "")
        self.port = read("port", -1)
        self.tls_path = read("tls_path", "./")
        # check_config() is inherited; the two checks below are specific to
        # this task.
        self.check_config()
        self._check_file_format()
        self._check_tls_path_config()

    def _check_file_format(self):
        """Reject a ``file_format`` that is set but is neither "npy" nor "bin".

        ``self.file_format`` is not assigned in this class; presumably it is
        set by ``BaseConfig.__init__`` (not visible here).
        """
        fmt = self.file_format
        if fmt is None:
            return
        if fmt not in ["npy", "bin"]:
            raise Exception("file_format is invalid")

    def _check_tls_path_config(self):
        """Reject a non-empty ``tls_path`` that does not exist on disk."""
        path = self.tls_path
        if not path:
            return
        if not os.path.exists(path):
            raise Exception("tls_path: %s does not exist" % path)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class StatisticsConfig(BaseConfig):
    """Task-level configuration for the statistics task."""

    def __init__(self, json_config):
        """Build the config from ``json_config`` and validate it.

        Raises:
            Exception: if ``summary_mode`` is set to an unsupported value.
        """
        super().__init__(json_config)
        self.check_config()
        self._check_summary_mode()

    def _check_summary_mode(self):
        """Reject a truthy ``summary_mode`` other than "statistics" or "md5".

        ``self.summary_mode`` is not assigned in this class; presumably it is
        set by ``BaseConfig.__init__`` (not visible here). Falsy values
        (``None``, empty string) are accepted unchanged.
        """
        mode = self.summary_mode
        if not mode:
            return
        if mode not in ["statistics", "md5"]:
            raise Exception("summary_mode is invalid")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class OverflowCheckConfig(BaseConfig):
    """Task-level configuration for the overflow-check task.

    Reads ``overflow_nums`` and ``check_mode`` from ``json_config``; both are
    optional and default to ``None``.
    """

    def __init__(self, json_config):
        """Read the overflow-check settings and validate them.

        Raises:
            Exception: if ``overflow_nums`` or ``check_mode`` is invalid.
        """
        super().__init__(json_config)
        self.overflow_nums = json_config.get("overflow_nums")
        self.check_mode = json_config.get("check_mode")
        self.check_overflow_config()

    def check_overflow_config(self):
        """Validate ``overflow_nums`` and ``check_mode``.

        ``overflow_nums`` must be ``None`` or a real integer. ``bool`` is a
        subclass of ``int`` in Python, so a JSON ``true``/``false`` would pass
        a plain ``isinstance(..., int)`` test; it is rejected explicitly here.
        ``check_mode`` must be ``None`` or one of "all", "aicore", "atomic".

        Raises:
            Exception: if either value is invalid.
        """
        nums = self.overflow_nums
        if nums is not None and (isinstance(nums, bool) or not isinstance(nums, int)):
            raise Exception("overflow_num is invalid")
        if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]:
            raise Exception("check_mode is invalid")
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class FreeBenchmarkCheckConfig(BaseConfig):
    """Task-level configuration for the free-benchmark task.

    Reads the perturbation settings (``fuzz_device``, ``pert_mode``,
    ``handler_type``, ``fuzz_level``, ``fuzz_stage``) and the preheat settings
    (``if_preheat``, ``preheat_step``, ``max_sample``) from ``json_config``,
    falling back to the defaults in ``PytorchFreeBenchmarkConst``, and
    validates them on construction.

    Every validation failure is reported through
    ``logger.error_log_with_exp`` together with a ``MsprobeException`` carrying
    ``INVALID_PARAM_ERROR``. That helper is defined outside this file;
    presumably it logs and then raises the exception it is given.
    """

    def __init__(self, json_config):
        """Read the free-benchmark settings and validate them."""
        super().__init__(json_config)
        self.fuzz_device = json_config.get("fuzz_device", PytorchFreeBenchmarkConst.DEFAULT_DEVICE)
        self.pert_mode = json_config.get("pert_mode", PytorchFreeBenchmarkConst.DEFAULT_MODE)
        self.handler_type = json_config.get("handler_type", PytorchFreeBenchmarkConst.DEFAULT_HANDLER)
        self.fuzz_level = json_config.get("fuzz_level", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_LEVEL)
        self.fuzz_stage = json_config.get("fuzz_stage", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_STAGE)
        self.if_preheat = json_config.get("if_preheat", False)
        self.preheat_step = json_config.get("preheat_step", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
        # NOTE(review): max_sample falls back to DEFAULT_PREHEAT_STEP, the same
        # constant as preheat_step. Confirm whether a dedicated max-sample
        # default was intended; left unchanged because no such constant is
        # visible from this file.
        self.max_sample = json_config.get("max_sample", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
        self.check_freebenchmark_config()

    @staticmethod
    def _raise_invalid_param(msg):
        """Report ``msg`` as an invalid-parameter error via the shared logger."""
        logger.error_log_with_exp(
            msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
        )

    @classmethod
    def _require_member(cls, name, value, allowed):
        """Report an error unless ``value`` is one of ``allowed``."""
        if value not in allowed:
            cls._raise_invalid_param(
                f"{name} is invalid, it should be one of {allowed}"
            )

    @classmethod
    def _require_positive_int(cls, name, value):
        """Report an error unless ``value`` is an integer greater than zero.

        ``bool`` is a subclass of ``int``, so it is excluded explicitly;
        otherwise a JSON ``true`` would be accepted as the integer 1.
        """
        if isinstance(value, bool) or not isinstance(value, int):
            cls._raise_invalid_param(f"{name} is invalid, it should be an integer")
        if value <= 0:
            cls._raise_invalid_param(f"{name} must be greater than 0")

    def check_freebenchmark_config(self):
        """Run every field check, then the handler- and preheat-specific ones."""
        self._check_pert_mode()
        self._check_fuzz_device()
        self._check_handler_type()
        self._check_fuzz_stage()
        self._check_fuzz_level()
        self._check_if_preheat()
        if self.handler_type == HandlerType.FIX:
            self._check_fix_config()
        if self.if_preheat:
            self._check_preheat_config()

    def _check_pert_mode(self):
        """``pert_mode`` must be a known perturbation mode."""
        self._require_member(
            "pert_mode", self.pert_mode, PytorchFreeBenchmarkConst.PERTURBATION_MODE_LIST
        )

    def _check_fuzz_device(self):
        """``fuzz_device`` must be a known device and must match ``pert_mode``.

        The CPU device and the CPU-only perturbation modes have to be selected
        together: choosing exactly one of the two is an error (hence the XOR).
        """
        self._require_member(
            "fuzz_device", self.fuzz_device, PytorchFreeBenchmarkConst.DEVICE_LIST
        )
        uses_cpu_device = self.fuzz_device == DeviceType.CPU
        uses_cpu_mode = self.pert_mode in PytorchFreeBenchmarkConst.CPU_MODE_LIST
        if uses_cpu_device ^ uses_cpu_mode:
            self._raise_invalid_param(
                f"You need to and can only set fuzz_device as {DeviceType.CPU} "
                f"when pert_mode in {PytorchFreeBenchmarkConst.CPU_MODE_LIST}"
            )

    def _check_handler_type(self):
        """``handler_type`` must be a known handler."""
        self._require_member(
            "handler_type", self.handler_type, PytorchFreeBenchmarkConst.HANDLER_LIST
        )

    def _check_fuzz_stage(self):
        """``fuzz_stage`` must be a known stage."""
        self._require_member(
            "fuzz_stage", self.fuzz_stage, PytorchFreeBenchmarkConst.FUZZ_STAGE_LIST
        )

    def _check_fuzz_level(self):
        """``fuzz_level`` must be a known level."""
        self._require_member(
            "fuzz_level", self.fuzz_level, PytorchFreeBenchmarkConst.FUZZ_LEVEL_LIST
        )

    def _check_if_preheat(self):
        """``if_preheat`` must be a real boolean."""
        if not isinstance(self.if_preheat, bool):
            self._raise_invalid_param("if_preheat is invalid, it should be a boolean")

    def _check_preheat_config(self):
        """When preheating is on, ``preheat_step`` and ``max_sample`` must be positive integers."""
        self._require_positive_int("preheat_step", self.preheat_step)
        self._require_positive_int("max_sample", self.max_sample)

    def _check_fix_config(self):
        """Extra constraints that apply only to the FIX handler.

        Preheating is not allowed, and ``fuzz_stage`` / ``pert_mode`` are
        restricted to the FIX-specific lists.
        """
        if self.if_preheat:
            self._raise_invalid_param(
                f"Preheating is not supported for {HandlerType.FIX} handler type"
            )
        if self.fuzz_stage not in PytorchFreeBenchmarkConst.FIX_STAGE_LIST:
            self._raise_invalid_param(
                f"The fuzz_stage when opening {HandlerType.FIX} handler must be one of "
                f"{PytorchFreeBenchmarkConst.FIX_STAGE_LIST}"
            )
        if self.pert_mode not in PytorchFreeBenchmarkConst.FIX_MODE_LIST:
            self._raise_invalid_param(
                f"The pert_mode when opening {HandlerType.FIX} handler must be one of "
                f"{PytorchFreeBenchmarkConst.FIX_MODE_LIST}"
            )
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class RunUTConfig(BaseConfig):
    """Task-level configuration for the run_ut task.

    Reads the API filter lists, the error-data path and the online-mode
    connection settings from ``json_config`` and validates them on
    construction.
    """

    # Result of get_ops(), evaluated once when the class body runs; entries of
    # white_list / black_list must be members of it.
    WrapApi = get_ops()

    def __init__(self, json_config):
        """Read the run_ut settings from ``json_config`` and validate them."""
        super().__init__(json_config)
        read = json_config.get
        self.white_list = read("white_list", Const.DEFAULT_LIST)
        self.black_list = read("black_list", Const.DEFAULT_LIST)
        self.error_data_path = read("error_data_path", Const.DEFAULT_PATH)
        self.is_online = read("is_online", False)
        self.nfs_path = read("nfs_path", "")
        self.host = read("host", "")
        self.port = read("port", -1)
        self.rank_list = read("rank_list", Const.DEFAULT_LIST)
        self.tls_path = read("tls_path", "./")
        self.check_run_ut_config()

    @classmethod
    def check_filter_list_config(cls, key, filter_list):
        """Validate an API filter list.

        Args:
            key: name of the setting, used only in the error messages.
            filter_list: value to validate.

        Raises:
            Exception: if ``filter_list`` is not a list, contains a non-string
                element, or contains a name that is not in ``cls.WrapApi``.
        """
        if not isinstance(filter_list, list):
            raise Exception("%s must be a list type" % key)
        if any(not isinstance(entry, str) for entry in filter_list):
            raise Exception("All elements in %s must be string type" % key)
        unknown = [entry for entry in filter_list if entry not in cls.WrapApi]
        if unknown:
            raise Exception("Invalid api in %s: %s" % (key, unknown))

    @classmethod
    def check_error_data_path_config(cls, error_data_path):
        """Raise ``Exception`` if ``error_data_path`` does not exist on disk."""
        if os.path.exists(error_data_path):
            return
        raise Exception("error_data_path: %s does not exist" % error_data_path)

    @classmethod
    def check_nfs_path_config(cls, nfs_path):
        """Raise ``Exception`` if a non-empty ``nfs_path`` does not exist on disk."""
        if not nfs_path:
            return
        if not os.path.exists(nfs_path):
            raise Exception("nfs_path: %s does not exist" % nfs_path)

    @classmethod
    def check_tls_path_config(cls, tls_path):
        """Raise ``Exception`` if a non-empty ``tls_path`` does not exist on disk."""
        if not tls_path:
            return
        if not os.path.exists(tls_path):
            raise Exception("tls_path: %s does not exist" % tls_path)

    def check_run_ut_config(self):
        """Validate both filter lists and then the three path settings."""
        filters = (
            (Const.WHITE_LIST, self.white_list),
            (Const.BLACK_LIST, self.black_list),
        )
        for key, names in filters:
            RunUTConfig.check_filter_list_config(key, names)
        RunUTConfig.check_error_data_path_config(self.error_data_path)
        RunUTConfig.check_nfs_path_config(self.nfs_path)
        RunUTConfig.check_tls_path_config(self.tls_path)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
class GradToolConfig(BaseConfig):
    """Task-level configuration for the gradient probe task.

    Reads ``grad_level`` (default "L1"), ``param_list`` (default ``[]``) and
    ``bounds`` (default ``[-1, 0, 1]``) from ``json_config``.
    """

    def __init__(self, json_config):
        """Read the grad-probe settings from ``json_config`` and validate them."""
        super().__init__(json_config)
        read = json_config.get
        self.grad_level = read("grad_level", "L1")
        self.param_list = read("param_list", [])
        self.bounds = read("bounds", [-1, 0, 1])
        self._check_config()

    def _check_config(self):
        """Validate ``grad_level``, ``param_list`` and ``bounds``.

        Raises:
            Exception: if ``grad_level`` is not a key of ``level_adp`` or
                ``param_list`` is not a list. ``bounds`` is delegated to
                ``check_bounds``, which is defined outside this file.
        """
        known_levels = level_adp.keys()
        if self.grad_level not in known_levels:
            raise Exception(f"grad_level must be one of {level_adp.keys()}")
        if not isinstance(self.param_list, list):
            raise Exception("param_list must be a list")
        check_bounds(self.bounds)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def parse_task_config(task, json_config):
    """Build the task-specific config object for ``task``.

    The section of ``json_config`` stored under the task's own name is passed
    to the matching config class; a missing section is treated as an empty
    dict. An unrecognised ``task`` yields a ``StatisticsConfig`` built from an
    empty dict.

    Args:
        task: task name, compared against the ``Const`` task constants.
        json_config: the parsed configuration mapping.

    Returns:
        An instance of the config class that corresponds to ``task``.
    """
    empty_section = {}
    # Ordered (task name, config class) pairs, checked with == in the same
    # order as the task constants are listed here.
    task_table = (
        (Const.TENSOR, TensorConfig),
        (Const.STATISTICS, StatisticsConfig),
        (Const.OVERFLOW_CHECK, OverflowCheckConfig),
        (Const.FREE_BENCHMARK, FreeBenchmarkCheckConfig),
        (Const.RUN_UT, RunUTConfig),
        (Const.GRAD_PROBE, GradToolConfig),
    )
    for task_name, config_cls in task_table:
        if task == task_name:
            return config_cls(json_config.get(task_name, empty_section))
    return StatisticsConfig(empty_section)
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def parse_json_config(json_file_path, task):
    """Load a config file and return its common and task-specific parts.

    Args:
        json_file_path: path of the JSON config file. When falsy, the
            ``config.json`` located one directory above this module's
            directory is used instead.
        task: task name to build the task config for. When falsy, the task
            recorded in the common config is used.

    Returns:
        A ``(common_config, task_config)`` tuple.
    """
    path = json_file_path
    if not path:
        # Two dirname() calls: this file's directory, then its parent.
        package_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        path = os.path.join(package_root, "config.json")
    raw_config = load_json(path)
    common_config = CommonConfig(raw_config)
    selected_task = task if task else common_config.task
    task_config = parse_task_config(selected_task, raw_config)
    return common_config, task_config
|