mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +131 -237
- msprobe/__init__.py +16 -1
- msprobe/{config/config.json → config.json} +47 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +58 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +402 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +523 -283
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +86 -69
- msprobe/core/common/utils.py +371 -616
- msprobe/core/common_config.py +78 -71
- msprobe/core/compare/acc_compare.py +472 -298
- msprobe/core/compare/check.py +180 -95
- msprobe/core/compare/compare_cli.py +69 -49
- msprobe/core/compare/highlight.py +259 -222
- msprobe/core/compare/multiprocessing_compute.py +174 -149
- msprobe/core/compare/npy_compare.py +310 -295
- msprobe/core/compare/utils.py +464 -429
- msprobe/core/data_dump/data_collector.py +153 -144
- msprobe/core/data_dump/data_processor/base.py +337 -293
- msprobe/core/data_dump/data_processor/factory.py +76 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
- msprobe/core/data_dump/json_writer.py +117 -116
- msprobe/core/data_dump/scope.py +194 -178
- msprobe/core/grad_probe/constant.py +74 -70
- msprobe/core/grad_probe/grad_compare.py +170 -175
- msprobe/core/grad_probe/utils.py +77 -52
- msprobe/docs/01.installation.md +99 -0
- msprobe/docs/02.config_introduction.md +137 -0
- msprobe/docs/03.config_examples.md +237 -0
- msprobe/docs/04.acl_config_examples.md +78 -0
- msprobe/docs/05.data_dump_PyTorch.md +326 -0
- msprobe/docs/06.data_dump_MindSpore.md +285 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
- msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
- msprobe/docs/FAQ.md +189 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +2 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +58 -34
- msprobe/mindspore/common/const.py +108 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +97 -57
- msprobe/mindspore/compare/distributed_compare.py +62 -75
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +357 -117
- msprobe/mindspore/compare/ms_graph_compare.py +364 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +69 -74
- msprobe/mindspore/debugger/precision_debugger.py +150 -107
- msprobe/mindspore/dump/dump_tool_factory.py +50 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
- msprobe/mindspore/dump/jit_dump.py +96 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
- msprobe/mindspore/free_benchmark/common/config.py +27 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
- msprobe/mindspore/free_benchmark/common/utils.py +85 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
- msprobe/mindspore/grad_probe/global_context.py +100 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +297 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +23 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +30 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
- msprobe/pytorch/bench_functions/linear.py +27 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
- msprobe/pytorch/bench_functions/rms_norm.py +30 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
- msprobe/pytorch/bench_functions/swiglu.py +70 -55
- msprobe/pytorch/common/__init__.py +17 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +33 -32
- msprobe/pytorch/common/parse_json.py +54 -39
- msprobe/pytorch/common/utils.py +310 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +49 -33
- msprobe/pytorch/compare/pt_compare.py +82 -40
- msprobe/pytorch/debugger/debugger_config.py +108 -95
- msprobe/pytorch/debugger/precision_debugger.py +173 -125
- msprobe/pytorch/free_benchmark/__init__.py +23 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +65 -37
- msprobe/pytorch/free_benchmark/common/params.py +144 -129
- msprobe/pytorch/free_benchmark/common/utils.py +118 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
- msprobe/pytorch/free_benchmark/main.py +120 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
- msprobe/pytorch/function_factory.py +91 -75
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +166 -161
- msprobe/pytorch/hook_module/hook_module.py +118 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +28 -29
- msprobe/pytorch/hook_module/wrap_aten.py +111 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
- msprobe/pytorch/hook_module/wrap_functional.py +104 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
- msprobe/pytorch/hook_module/wrap_torch.py +84 -86
- msprobe/pytorch/hook_module/wrap_vf.py +60 -62
- msprobe/pytorch/module_processer.py +153 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +235 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
- msprobe/pytorch/online_dispatch/utils.py +127 -146
- msprobe/pytorch/parse.py +19 -4
- msprobe/pytorch/parse_tool/cli.py +31 -32
- msprobe/pytorch/parse_tool/lib/compare.py +259 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
- msprobe/pytorch/parse_tool/lib/utils.py +320 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +317 -187
- msprobe/pytorch/service.py +311 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
|
@@ -1,125 +1,173 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
@
|
|
86
|
-
def
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
if
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections import namedtuple
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from msprobe.core.common.const import Const, FileCheckConst, MsgConst
|
|
20
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
21
|
+
from msprobe.core.common.file_utils import FileChecker
|
|
22
|
+
from msprobe.core.common.utils import get_real_step_or_rank
|
|
23
|
+
from msprobe.pytorch.common.log import logger
|
|
24
|
+
from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
|
|
25
|
+
from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
|
|
26
|
+
from msprobe.pytorch.pt_config import parse_json_config
|
|
27
|
+
from msprobe.pytorch.service import Service
|
|
28
|
+
from torch.utils.data import dataloader
|
|
29
|
+
|
|
30
|
+
# Immutable bundle of the user-supplied constructor arguments that
# PrecisionDebugger.check_input_params validates in one pass.
ConfigParameters = namedtuple(
    "ConfigParameters",
    "config_path task dump_path level model",
)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class PrecisionDebugger:
    """Process-wide singleton front end for msprobe precision debugging on PyTorch.

    The first successful construction parses the JSON config and builds either
    a ``GradientMonitor`` (task ``Const.GRAD_PROBE``) or a ``DebuggerConfig``
    plus a dump ``Service``. Later constructions return the same object and do
    not reconfigure it. Collection is driven through the class methods
    ``start``, ``stop``, ``step``, ``forward_backward_dump_end`` and ``monitor``.
    """

    _instance = None
    # Tasks that never create a dump Service; start/stop/step and
    # forward_backward_dump_end return immediately for them.
    tasks_not_need_debugger = [Const.GRAD_PROBE]

    def __new__(cls, *args, **kwargs):
        # Singleton: always hand back the same object, seeding defaults only
        # the first time it is created.
        if cls._instance is None:
            cls._instance = super(PrecisionDebugger, cls).__new__(cls)
            cls._instance.config = None
            cls._instance.enable_dataloader = False
        return cls._instance

    def __init__(
        self,
        config_path=None,
        task=None,
        dump_path=None,
        level=None,
        model=None,
        step=None,
    ):
        """Configure the singleton on its first successful construction.

        Args:
            config_path: Path of the JSON config file (checked with FileChecker).
            task: Task name; when given it takes precedence over the config's
                task. Must be in ``Const.TASK_LIST``.
            dump_path: Dump path string, passed on to ``DebuggerConfig``.
            level: Dump level; must be in ``Const.LEVEL_LIST``.
            model: Optional ``torch.nn.Module``; stored and handed to
                ``Service.start`` by ``start()``.
            step: Optional step selection; when truthy it is normalised with
                ``get_real_step_or_rank`` and stored on the common config.

        Raises:
            MsprobeException: If an argument fails validation.
        """
        if hasattr(self, "initialized"):
            # Already configured: re-construction is a no-op.
            return

        config_params = ConfigParameters(config_path, task, dump_path, level, model)
        self.check_input_params(config_params)

        self.api_origin = False
        self.model = model
        common_config, task_config = parse_json_config(config_path, task)
        self.task = task if task else common_config.task
        if self.task == Const.GRAD_PROBE:
            # Gradient probing uses its own monitor instead of the dump Service.
            self.gm = GradientMonitor(common_config, task_config)
            self.initialized = True
            return
        if step:
            common_config.step = get_real_step_or_rank(step, Const.STEP)
        self.config = DebuggerConfig(
            common_config, task_config, task, dump_path, level
        )
        self.service = Service(self.config)
        self.enable_dataloader = self.config.enable_dataloader
        if self.enable_dataloader:
            logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
            # Wrap DataLoader iteration so every fetched batch drives stop/step/start.
            dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)
        # Mark as initialised only after setup fully succeeded. This way a failed
        # construction (bad config, service error) does not leave a half-built
        # singleton that silently ignores the next construction attempt.
        self.initialized = True

    @property
    def instance(self):
        """Return the singleton.

        Note: as a property this only resolves on an instance; evaluated on the
        class (``PrecisionDebugger.instance``) it yields the property object.
        """
        return self._instance

    @classmethod
    def _get_instance(cls):
        """Return the singleton, raising if PrecisionDebugger was never constructed."""
        if not cls._instance:
            raise Exception(MsgConst.NOT_CREATED_INSTANCE)
        return cls._instance

    @staticmethod
    def check_input_params(args):
        """Validate the constructor arguments; ``None`` fields are skipped.

        Args:
            args: A ``ConfigParameters`` tuple.

        Raises:
            MsprobeException: With ``INVALID_PARAM_ERROR`` for any invalid field.
        """
        if args.config_path is not None:
            if not isinstance(args.config_path, str):
                raise MsprobeException(
                    MsprobeException.INVALID_PARAM_ERROR, "config_path must be a string")
            file_checker = FileChecker(
                file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
            file_checker.common_check()

        if args.task is not None and args.task not in Const.TASK_LIST:
            raise MsprobeException(
                MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}")

        if args.dump_path is not None:
            if not isinstance(args.dump_path, str):
                raise MsprobeException(
                    MsprobeException.INVALID_PARAM_ERROR, "dump_path must be a string")

        if args.level is not None and args.level not in Const.LEVEL_LIST:
            raise MsprobeException(
                MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")

        if args.model is not None and not isinstance(args.model, torch.nn.Module):
            raise MsprobeException(
                MsprobeException.INVALID_PARAM_ERROR, "model must be a torch.nn.Module")

    @classmethod
    def start(cls, model=None):
        """Start data collection through the dump Service.

        Skipped (with a warning) while DataLoader tracing is active, since the
        DataLoader wrapper then drives start/stop itself.

        Args:
            model: Optional model forwarded to ``DebuggerConfig.check_model``.

        Raises:
            Exception: If the debugger has not been constructed.
        """
        instance = cls._get_instance()
        if instance.task in PrecisionDebugger.tasks_not_need_debugger:
            return
        instance.config.check_model(instance, model)
        if instance.enable_dataloader:
            logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
        else:
            instance.service.start(instance.model, instance.api_origin)
            instance.api_origin = False

    # Mark the end of forward/backward dumping for the selected code segment;
    # computation data produced after this call is ignored and not dumped.
    @classmethod
    def forward_backward_dump_end(cls):
        """Tell the service the dumped forward/backward segment has ended.

        Raises:
            Exception: If the debugger has not been constructed.
        """
        instance = cls._get_instance()
        # Tasks without a dump Service (e.g. grad probe) have nothing to end;
        # return like start/stop/step instead of failing on ``instance.service``.
        if instance.task in PrecisionDebugger.tasks_not_need_debugger:
            return
        instance.service.forward_backward_dump_end()
        instance.api_origin = True

    @classmethod
    def stop(cls):
        """Stop data collection; skipped (with a warning) while DataLoader tracing is active.

        Raises:
            Exception: If the debugger has not been constructed.
        """
        instance = cls._get_instance()
        if instance.task in PrecisionDebugger.tasks_not_need_debugger:
            return
        if instance.enable_dataloader:
            logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")
        else:
            instance.service.stop()

    @classmethod
    def step(cls):
        """Advance the dump service to the next step.

        Raises:
            Exception: If the debugger has not been constructed.
        """
        instance = cls._get_instance()
        if instance.task in PrecisionDebugger.tasks_not_need_debugger:
            return
        instance.service.step()

    @classmethod
    def monitor(cls, model):
        """Attach the gradient monitor to ``model``; only acts for task ``Const.GRAD_PROBE``.

        Raises:
            Exception: If the debugger has not been constructed.
        """
        instance = cls._get_instance()
        if instance.task != Const.GRAD_PROBE:
            return
        instance.gm.monitor(model)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def iter_tracer(func):
    """Wrap a DataLoader iterator's ``__next__`` so each batch fetch marks a step boundary.

    Before fetching a batch the wrapper stops the running collection and
    advances the step, unless the service reports ``first_start``. After
    fetching, it restarts collection. ``enable_dataloader`` is cleared for the
    duration because ``PrecisionDebugger.start``/``stop`` skip their work while
    it is True.

    Args:
        func: The original ``__next__`` implementation.

    Returns:
        A wrapper with the same call signature that returns ``func``'s result.
    """
    import functools

    @functools.wraps(func)
    def func_wrapper(*args, **kwargs):
        # ``PrecisionDebugger.instance`` is a property: evaluated on the class it
        # returns the property object itself, not the singleton. Read the class
        # attribute that holds the singleton instead.
        debugger_instance = PrecisionDebugger._instance
        debugger_instance.enable_dataloader = False
        if not debugger_instance.service.first_start:
            debugger_instance.stop()
            debugger_instance.step()
        result = func(*args, **kwargs)
        debugger_instance.start()
        debugger_instance.enable_dataloader = True
        return result

    return func_wrapper
|
|
@@ -1,8 +1,23 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# Public names of this package for ``from ... import *``; both are re-exported
# from the submodule imports below.
__all__ = ["FreeBenchmarkCheck", "UnequalRow"]
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.core.common.exceptions import FreeBenchmarkException
|
|
20
|
+
from msprobe.pytorch.common.log import logger
|
|
21
|
+
|
|
22
|
+
from .common.params import UnequalRow
|
|
23
|
+
from .main import FreeBenchmarkCheck
|
|
@@ -1,70 +1,70 @@
|
|
|
1
|
-
from typing import Dict
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import torch
|
|
5
|
-
from msprobe.pytorch.free_benchmark.common.enums import FuzzThreshold
|
|
6
|
-
from msprobe.pytorch.free_benchmark.common.params import BenchmarkThd
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class CommonField:
    """Shared string constants used across the free-benchmark modules."""

    # Name strings (presumably attribute / dict keys — confirm at call sites).
    DEVICE = "device"
    META = "meta"
    FUZZ_TENSOR = "fuzz_tensor"
    REQUIRES_GRAD = "requires_grad"
    HOLD_PLACE = "hold_place"
    # Module-path prefix string for torch.distributed APIs.
    DISTRIBUTED_OP = "torch.distributed"
    GRADSAVER = "grad_saver"
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class ThresholdConfig:
    """Numeric thresholds and per-dtype lookup tables for free-benchmark checks."""

    # Perturbation magnitude per floating dtype, taken from FuzzThreshold.
    PERTURBATION_VALUE_DICT: Dict = {
        torch.bfloat16: FuzzThreshold.BF16_THD,
        torch.float16: FuzzThreshold.F16_THD,
        torch.float32: FuzzThreshold.F32_THD,
        torch.float64: FuzzThreshold.F64_THD,
    }

    # Absolute tolerance per dtype: the same entries, kept as an independent dict.
    ABS_TOL_VALUE_DICT: Dict = dict(PERTURBATION_VALUE_DICT)

    # Bit flipping must reinterpret a float as an integer of equal or greater width.
    PERTURBATION_BIT_DICT = {
        torch.bfloat16: torch.int16,
        torch.float16: torch.int16,
        torch.float32: torch.int32,
        torch.float64: torch.int64,
    }

    # Lower bound for input noise.
    NOISE_INPUT_LOWER_BOUND = 1e-8
    COMP_CONSISTENT = 1.0
    COMP_NAN = np.nan
    SYMBOL_FLIPPING = "symbol_flipping"
    BACKWARD_OUTPUT_LOWER_BOUND = 1e-3
    SMALL_VALUE = 1.0

    # Initial threshold used during preheating.
    PREHEAT_INITIAL_THD = 2.05
    API_THD_STEP = 2.0

    DTYPE_PER_THD = {
        torch.float16: 1.002,
        torch.bfloat16: 1.004,
        torch.float32: 1.0002,
    }
    BENCHMARK_THD_DICT = {
        torch.float32: BenchmarkThd(2 ** -14, 1.0, 2 ** -14, 1e-4),
        torch.float16: BenchmarkThd(2 ** -11, 1.0, 2 ** -11, 1e-4),
        torch.bfloat16: BenchmarkThd(2 ** -8, 1.0, 2 ** -8, 1e-4),
    }

    TENSOR_SPLIT_MAX_CHUNK = 128
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
class PreheatConfig:
    """String constants naming the preheat options (presumably config keys — confirm)."""

    IF_PREHEAT = "if_preheat"      # whether preheating is enabled
    PREHEAT_STEP = "preheat_step"  # number of preheat steps
    MAX_SAMPLE = "max_sample"      # sampling cap
|
|
1
|
+
from typing import Dict
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from msprobe.pytorch.free_benchmark.common.enums import FuzzThreshold
|
|
6
|
+
from msprobe.pytorch.free_benchmark.common.params import BenchmarkThd
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class CommonField:
    """Shared string constants used across the free-benchmark modules."""

    # Name strings (presumably attribute / dict keys — confirm at call sites).
    DEVICE = "device"
    META = "meta"
    FUZZ_TENSOR = "fuzz_tensor"
    REQUIRES_GRAD = "requires_grad"
    HOLD_PLACE = "hold_place"
    # Module-path prefix string for torch.distributed APIs.
    DISTRIBUTED_OP = "torch.distributed"
    GRADSAVER = "grad_saver"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ThresholdConfig:
    """Numeric thresholds and per-dtype lookup tables for free-benchmark checks."""

    # Perturbation magnitude per floating dtype, taken from FuzzThreshold.
    PERTURBATION_VALUE_DICT: Dict = {
        torch.bfloat16: FuzzThreshold.BF16_THD,
        torch.float16: FuzzThreshold.F16_THD,
        torch.float32: FuzzThreshold.F32_THD,
        torch.float64: FuzzThreshold.F64_THD,
    }

    # Absolute tolerance per dtype: the same entries, kept as an independent dict.
    ABS_TOL_VALUE_DICT: Dict = dict(PERTURBATION_VALUE_DICT)

    # Bit flipping must reinterpret a float as an integer of equal or greater width.
    PERTURBATION_BIT_DICT = {
        torch.bfloat16: torch.int16,
        torch.float16: torch.int16,
        torch.float32: torch.int32,
        torch.float64: torch.int64,
    }

    # Lower bound for input noise.
    NOISE_INPUT_LOWER_BOUND = 1e-8
    COMP_CONSISTENT = 1.0
    COMP_NAN = np.nan
    SYMBOL_FLIPPING = "symbol_flipping"
    BACKWARD_OUTPUT_LOWER_BOUND = 1e-3
    SMALL_VALUE = 1.0

    # Initial threshold used during preheating.
    PREHEAT_INITIAL_THD = 2.05
    API_THD_STEP = 2.0

    DTYPE_PER_THD = {
        torch.float16: 1.002,
        torch.bfloat16: 1.004,
        torch.float32: 1.0002,
    }
    BENCHMARK_THD_DICT = {
        torch.float32: BenchmarkThd(2 ** -14, 1.0, 2 ** -14, 1e-4),
        torch.float16: BenchmarkThd(2 ** -11, 1.0, 2 ** -11, 1e-4),
        torch.bfloat16: BenchmarkThd(2 ** -8, 1.0, 2 ** -8, 1e-4),
    }

    TENSOR_SPLIT_MAX_CHUNK = 128
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class PreheatConfig:
    """Names of the preheat (warm-up) settings.

    These appear to be configuration keys; verify against the config loader.
    """

    IF_PREHEAT = "if_preheat"
    PREHEAT_STEP = "preheat_step"
    MAX_SAMPLE = "max_sample"
|
|
@@ -1,72 +1,72 @@
|
|
|
1
|
-
from collections import defaultdict
|
|
2
|
-
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
class PreheatCounter:
    """Per-API bookkeeping for the free-benchmark preheat (warm-up) phase.

    State is keyed by API name and, where relevant, by ``str(dtype)``:

    * call and sample counters for the current step (reset on step change);
    * a usage counter fed by ``add_one_step_used_api`` (never reset here);
    * preheat comparison records (reset on step change);
    * current thresholds and per-dtype "preheat still enabled" flags.
    """

    def __init__(self) -> None:
        # Counters reset by clear_step() whenever the step changes.
        self.api_called_time: dict = defaultdict(int)
        self.api_sample_time: dict = defaultdict(int)
        # Not touched by clear_step(), so it accumulates across steps.
        self.one_step_used_api: dict = defaultdict(int)
        # api_name -> {str(dtype): threshold}
        self.api_thd: dict = defaultdict(dict)
        # api_name -> {str(dtype): [cmp_result, ...]}; reset by clear_step().
        self.preheat_record: dict = defaultdict(dict)
        # str(dtype) -> the original dtype object, to map string keys back.
        self.dtype_map: dict = {}
        # api_name -> {str(dtype): bool}; a missing entry means "enabled".
        self.if_preheat: dict = defaultdict(dict)
        self.step = 0

    def clear_step(self):
        """Drop the preheat records and the per-step call/sample counters."""
        for per_step_table in (self.preheat_record, self.api_called_time, self.api_sample_time):
            per_step_table.clear()

    def check_step(self, current_step):
        """Reset per-step state if ``current_step`` differs from the stored step."""
        if current_step == self.step:
            return
        self.clear_step()
        self.step = current_step

    def add_api_called_time(self, api_name: str):
        """Increment the current-step call counter of ``api_name``."""
        self.api_called_time[api_name] += 1

    def get_api_called_time(self, api_name: str) -> int:
        """Return the current-step call count of ``api_name`` (0 if never called)."""
        return self.api_called_time[api_name]

    def add_api_sample_time(self, api_name: str):
        """Increment the current-step sample counter of ``api_name``."""
        self.api_sample_time[api_name] += 1

    def get_api_sample_time(self, api_name: str) -> int:
        """Return the current-step sample count of ``api_name`` (0 if never sampled)."""
        return self.api_sample_time[api_name]

    def add_one_step_used_api(self, api_name: str):
        """Increment the usage counter of ``api_name``; clear_step() does not reset it."""
        self.one_step_used_api[api_name] += 1

    def get_one_step_used_api(self, api_name: str):
        """Return the usage counter of ``api_name`` (0 if never recorded)."""
        return self.one_step_used_api[api_name]

    def update_preheat_record(self, api_name, dtype, cmp_result):
        """Append a CPU-benchmark comparison result recorded during preheat.

        The result is stored under ``str(dtype)``, and ``dtype_map`` remembers
        the original dtype object for that string key.
        """
        dtype_key = str(dtype)
        self.preheat_record[api_name].setdefault(dtype_key, []).append(cmp_result)
        self.dtype_map[dtype_key] = dtype

    def update_api_thd(self, api_name, dtype, threshold, dthreshold):
        """Store the larger of ``threshold`` and ``dthreshold`` for (api_name, dtype).

        Ties and unordered comparisons (e.g. a NaN ``threshold``) resolve to
        ``dthreshold``.
        """
        # Argument order matters: max() keeps its first argument unless a later
        # one compares strictly greater, so this equals
        # ``threshold if threshold > dthreshold else dthreshold``.
        self.api_thd[api_name][str(dtype)] = max(dthreshold, threshold)

    def get_api_thd(self, api_name, dtype):
        """Return the threshold for (api_name, dtype).

        On first access, the threshold is initialised to
        ``ThresholdConfig.PREHEAT_INITIAL_THD``.
        """
        dtype_key = str(dtype)
        thresholds = self.api_thd[api_name]
        if dtype_key not in thresholds:
            thresholds[dtype_key] = ThresholdConfig.PREHEAT_INITIAL_THD
            self.dtype_map[dtype_key] = dtype
        return thresholds[dtype_key]

    def set_api_preheat(self, api_name, dtype_str, is_preheat=True):
        """Set the preheat flag for ``api_name`` under the already-stringified ``dtype_str``.

        Used to mark dtypes found inconsistent with the CPU so they are no
        longer preheated.
        """
        self.if_preheat[api_name][dtype_str] = is_preheat

    def get_api_preheat(self, api_name, dtype):
        """Return whether preheat is still enabled for (api_name, dtype).

        Dtypes marked as CPU-inconsistent return the stored flag; unmarked
        dtypes default to True.
        """
        flags = self.if_preheat[api_name]
        return flags.get(str(dtype), True)
|
|
71
|
-
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class PreheatCounter:
    """Per-API bookkeeping for the free-benchmark preheat (warm-up) phase.

    State is keyed by API name and, where relevant, by ``str(dtype)``:

    * call and sample counters for the current step (reset on step change);
    * a usage counter fed by ``add_one_step_used_api`` (never reset here);
    * preheat comparison records (reset on step change);
    * current thresholds and per-dtype "preheat still enabled" flags.
    """

    def __init__(self) -> None:
        # Counters reset by clear_step() whenever the step changes.
        self.api_called_time: dict = defaultdict(int)
        self.api_sample_time: dict = defaultdict(int)
        # Not touched by clear_step(), so it accumulates across steps.
        self.one_step_used_api: dict = defaultdict(int)
        # api_name -> {str(dtype): threshold}
        self.api_thd: dict = defaultdict(dict)
        # api_name -> {str(dtype): [cmp_result, ...]}; reset by clear_step().
        self.preheat_record: dict = defaultdict(dict)
        # str(dtype) -> the original dtype object, to map string keys back.
        self.dtype_map: dict = {}
        # api_name -> {str(dtype): bool}; a missing entry means "enabled".
        self.if_preheat: dict = defaultdict(dict)
        self.step = 0

    def clear_step(self):
        """Drop the preheat records and the per-step call/sample counters."""
        for per_step_table in (self.preheat_record, self.api_called_time, self.api_sample_time):
            per_step_table.clear()

    def check_step(self, current_step):
        """Reset per-step state if ``current_step`` differs from the stored step."""
        if current_step == self.step:
            return
        self.clear_step()
        self.step = current_step

    def add_api_called_time(self, api_name: str):
        """Increment the current-step call counter of ``api_name``."""
        self.api_called_time[api_name] += 1

    def get_api_called_time(self, api_name: str) -> int:
        """Return the current-step call count of ``api_name`` (0 if never called)."""
        return self.api_called_time[api_name]

    def add_api_sample_time(self, api_name: str):
        """Increment the current-step sample counter of ``api_name``."""
        self.api_sample_time[api_name] += 1

    def get_api_sample_time(self, api_name: str) -> int:
        """Return the current-step sample count of ``api_name`` (0 if never sampled)."""
        return self.api_sample_time[api_name]

    def add_one_step_used_api(self, api_name: str):
        """Increment the usage counter of ``api_name``; clear_step() does not reset it."""
        self.one_step_used_api[api_name] += 1

    def get_one_step_used_api(self, api_name: str):
        """Return the usage counter of ``api_name`` (0 if never recorded)."""
        return self.one_step_used_api[api_name]

    def update_preheat_record(self, api_name, dtype, cmp_result):
        """Append a CPU-benchmark comparison result recorded during preheat.

        The result is stored under ``str(dtype)``, and ``dtype_map`` remembers
        the original dtype object for that string key.
        """
        dtype_key = str(dtype)
        self.preheat_record[api_name].setdefault(dtype_key, []).append(cmp_result)
        self.dtype_map[dtype_key] = dtype

    def update_api_thd(self, api_name, dtype, threshold, dthreshold):
        """Store the larger of ``threshold`` and ``dthreshold`` for (api_name, dtype).

        Ties and unordered comparisons (e.g. a NaN ``threshold``) resolve to
        ``dthreshold``.
        """
        # Argument order matters: max() keeps its first argument unless a later
        # one compares strictly greater, so this equals
        # ``threshold if threshold > dthreshold else dthreshold``.
        self.api_thd[api_name][str(dtype)] = max(dthreshold, threshold)

    def get_api_thd(self, api_name, dtype):
        """Return the threshold for (api_name, dtype).

        On first access, the threshold is initialised to
        ``ThresholdConfig.PREHEAT_INITIAL_THD``.
        """
        dtype_key = str(dtype)
        thresholds = self.api_thd[api_name]
        if dtype_key not in thresholds:
            thresholds[dtype_key] = ThresholdConfig.PREHEAT_INITIAL_THD
            self.dtype_map[dtype_key] = dtype
        return thresholds[dtype_key]

    def set_api_preheat(self, api_name, dtype_str, is_preheat=True):
        """Set the preheat flag for ``api_name`` under the already-stringified ``dtype_str``.

        Used to mark dtypes found inconsistent with the CPU so they are no
        longer preheated.
        """
        self.if_preheat[api_name][dtype_str] = is_preheat

    def get_api_preheat(self, api_name, dtype):
        """Return whether preheat is still enabled for (api_name, dtype).

        Dtypes marked as CPU-inconsistent return the stored flag; unmarked
        dtypes default to True.
        """
        flags = self.if_preheat[api_name]
        return flags.get(str(dtype), True)
|
|
71
|
+
|
|
72
72
|
# Module-level PreheatCounter instance created at import time; presumably
# shared by importers of this module rather than re-instantiated — TODO confirm.
preheat_counter = PreheatCounter()
|