mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
- mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +101 -182
- msprobe/__init__.py +1 -0
- msprobe/{config/config.json → config.json} +49 -27
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +124 -124
- msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
- msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +341 -241
- msprobe/core/common/exceptions.py +100 -88
- msprobe/core/common/{file_check.py → file_utils.py} +478 -265
- msprobe/core/common/log.py +76 -55
- msprobe/core/common/utils.py +385 -516
- msprobe/core/common_config.py +85 -58
- msprobe/core/compare/acc_compare.py +300 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +223 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
- msprobe/core/compare/utils.py +430 -0
- msprobe/core/data_dump/data_collector.py +154 -140
- msprobe/core/data_dump/data_processor/base.py +314 -245
- msprobe/core/data_dump/data_processor/factory.py +59 -61
- msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
- msprobe/core/data_dump/json_writer.py +96 -116
- msprobe/core/data_dump/scope.py +178 -178
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +171 -0
- msprobe/core/grad_probe/utils.py +64 -0
- msprobe/docs/01.installation.md +89 -0
- msprobe/docs/02.config_introduction.md +165 -0
- msprobe/docs/03.config_examples.md +247 -0
- msprobe/docs/04.acl_config_examples.md +76 -0
- msprobe/docs/05.data_dump_PyTorch.md +198 -0
- msprobe/docs/06.data_dump_MindSpore.md +243 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
- msprobe/docs/17.grad_probe.md +207 -0
- msprobe/docs/FAQ_PyTorch.md +177 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/grad_probe_image-1.png +0 -0
- msprobe/docs/img/grad_probe_image-2.png +0 -0
- msprobe/docs/img/grad_probe_image-3.png +0 -0
- msprobe/docs/img/grad_probe_image-4.png +0 -0
- msprobe/docs/img/grad_probe_image.png +0 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
- msprobe/mindspore/api_accuracy_checker/main.py +9 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +106 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +81 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +219 -0
- msprobe/mindspore/compare/ms_graph_compare.py +348 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +66 -51
- msprobe/mindspore/debugger/precision_debugger.py +126 -32
- msprobe/mindspore/dump/dump_tool_factory.py +35 -38
- msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
- msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
- msprobe/mindspore/dump/jit_dump.py +72 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +90 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +94 -0
- msprobe/mindspore/grad_probe/utils.py +30 -0
- msprobe/mindspore/ms_config.py +128 -78
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +378 -0
- msprobe/mindspore/task_handler_factory.py +24 -21
- msprobe/msprobe.py +105 -67
- msprobe/pytorch/__init__.py +4 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
- msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/__init__.py +2 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +20 -31
- msprobe/pytorch/common/parse_json.py +39 -37
- msprobe/pytorch/common/utils.py +305 -224
- msprobe/pytorch/compare/distributed_compare.py +66 -111
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +34 -36
- msprobe/pytorch/compare/pt_compare.py +50 -0
- msprobe/pytorch/debugger/debugger_config.py +95 -86
- msprobe/pytorch/debugger/precision_debugger.py +125 -95
- msprobe/pytorch/free_benchmark/__init__.py +8 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -67
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +37 -37
- msprobe/pytorch/free_benchmark/common/params.py +129 -129
- msprobe/pytorch/free_benchmark/common/utils.py +102 -98
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
- msprobe/pytorch/free_benchmark/main.py +105 -102
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
- msprobe/pytorch/function_factory.py +76 -0
- msprobe/pytorch/functional/dump_module.py +39 -39
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -161
- msprobe/pytorch/hook_module/hook_module.py +120 -109
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
- msprobe/pytorch/hook_module/utils.py +30 -29
- msprobe/pytorch/hook_module/wrap_aten.py +110 -100
- msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
- msprobe/pytorch/hook_module/wrap_functional.py +105 -108
- msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
- msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
- msprobe/pytorch/hook_module/wrap_torch.py +86 -88
- msprobe/pytorch/hook_module/wrap_vf.py +62 -64
- msprobe/pytorch/module_processer.py +138 -98
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +236 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -273
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
- msprobe/pytorch/online_dispatch/utils.py +130 -187
- msprobe/pytorch/parse.py +4 -4
- msprobe/pytorch/parse_tool/cli.py +32 -32
- msprobe/pytorch/parse_tool/lib/compare.py +260 -259
- msprobe/pytorch/parse_tool/lib/config.py +52 -51
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
- msprobe/pytorch/parse_tool/lib/utils.py +316 -367
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
- msprobe/pytorch/pt_config.py +188 -93
- msprobe/pytorch/service.py +246 -167
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/config/README.md +0 -397
- msprobe/mindspore/doc/dump.md +0 -65
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
- msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
- msprobe/pytorch/doc/dump.md +0 -207
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
|
@@ -1,116 +1,96 @@
|
|
|
1
|
-
import os
|
|
2
|
-
import csv
|
|
3
|
-
|
|
4
|
-
import
|
|
5
|
-
from
|
|
6
|
-
|
|
7
|
-
from msprobe.core.common.
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
self.
|
|
16
|
-
self.
|
|
17
|
-
self.
|
|
18
|
-
self.
|
|
19
|
-
self.
|
|
20
|
-
self.
|
|
21
|
-
self.
|
|
22
|
-
self.
|
|
23
|
-
self.
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
)
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
self.
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
def
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
def
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
fcntl.flock(f, fcntl.LOCK_UN)
|
|
98
|
-
|
|
99
|
-
self.cache_data[Const.DATA].clear()
|
|
100
|
-
|
|
101
|
-
def write_stack_info_json(self, file_path):
|
|
102
|
-
with open(file_path, 'w+') as f:
|
|
103
|
-
fcntl.flock(f, fcntl.LOCK_EX)
|
|
104
|
-
json.dump(self.cache_stack, f, indent=1)
|
|
105
|
-
fcntl.flock(f, fcntl.LOCK_UN)
|
|
106
|
-
|
|
107
|
-
def write_construct_info_json(self, file_path):
|
|
108
|
-
with open(file_path, 'w+') as f:
|
|
109
|
-
fcntl.flock(f, fcntl.LOCK_EX)
|
|
110
|
-
json.dump(self.cache_construct, f, indent=1)
|
|
111
|
-
fcntl.flock(f, fcntl.LOCK_UN)
|
|
112
|
-
|
|
113
|
-
def write_json(self):
|
|
114
|
-
self.write_data_json(self.dump_file_path)
|
|
115
|
-
self.write_stack_info_json(self.stack_file_path)
|
|
116
|
-
self.write_construct_info_json(self.construct_file_path)
|
|
1
|
+
import os
|
|
2
|
+
import csv
|
|
3
|
+
|
|
4
|
+
from msprobe.core.common.file_utils import change_mode, FileOpen
|
|
5
|
+
from msprobe.core.common.log import logger
|
|
6
|
+
from msprobe.core.common.const import Const, FileCheckConst
|
|
7
|
+
from msprobe.core.common.file_utils import remove_path, load_json, save_json
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DataWriter:
|
|
11
|
+
|
|
12
|
+
def __init__(self, init_json=None) -> None:
|
|
13
|
+
self.dump_count = 0
|
|
14
|
+
self.init_json = init_json
|
|
15
|
+
self.dump_file_path = None # os.path.join(dump_dir, DataWriter.dump_json_name)
|
|
16
|
+
self.stack_file_path = None # os.path.join(dump_dir, DataWriter.stack_json_name)
|
|
17
|
+
self.construct_file_path = None # os.path.join(dump_dir, DataWriter.construct_json_name)
|
|
18
|
+
self.free_benchmark_file_path = None
|
|
19
|
+
self.dump_tensor_data_dir = None
|
|
20
|
+
self.buffer_size = 1000
|
|
21
|
+
self.cache_data = {Const.DATA: {}}
|
|
22
|
+
self.cache_stack = {}
|
|
23
|
+
self.cache_construct = {}
|
|
24
|
+
|
|
25
|
+
@staticmethod
|
|
26
|
+
def write_data_to_csv(result: list, result_header: tuple, file_path: str):
|
|
27
|
+
if not result:
|
|
28
|
+
return
|
|
29
|
+
is_exists = os.path.exists(file_path)
|
|
30
|
+
append = "a+" if is_exists else "w+"
|
|
31
|
+
with FileOpen(file_path, append) as csv_file:
|
|
32
|
+
spawn_writer = csv.writer(csv_file)
|
|
33
|
+
if not is_exists:
|
|
34
|
+
spawn_writer.writerow(result_header)
|
|
35
|
+
spawn_writer.writerows([result,])
|
|
36
|
+
is_new_file = not is_exists
|
|
37
|
+
if is_new_file:
|
|
38
|
+
change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
39
|
+
|
|
40
|
+
def initialize_json_file(self, **kwargs):
|
|
41
|
+
kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
|
|
42
|
+
save_json(self.dump_file_path, kwargs)
|
|
43
|
+
|
|
44
|
+
empty_dict = {}
|
|
45
|
+
remove_path(self.stack_file_path)
|
|
46
|
+
save_json(self.stack_file_path, empty_dict)
|
|
47
|
+
|
|
48
|
+
remove_path(self.construct_file_path)
|
|
49
|
+
save_json(self.construct_file_path, empty_dict)
|
|
50
|
+
|
|
51
|
+
def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir,
|
|
52
|
+
free_benchmark_file_path):
|
|
53
|
+
self.dump_file_path = dump_file_path
|
|
54
|
+
self.stack_file_path = stack_file_path
|
|
55
|
+
self.construct_file_path = construct_file_path
|
|
56
|
+
self.dump_tensor_data_dir = dump_data_dir
|
|
57
|
+
self.free_benchmark_file_path = free_benchmark_file_path
|
|
58
|
+
|
|
59
|
+
def update_data(self, new_data):
|
|
60
|
+
key = next(iter(new_data.keys())) # assert len(new_data.keys()) == 1
|
|
61
|
+
if key in self.cache_data[Const.DATA]:
|
|
62
|
+
self.cache_data[Const.DATA][key].update(new_data[key])
|
|
63
|
+
else:
|
|
64
|
+
self.cache_data[Const.DATA].update(new_data)
|
|
65
|
+
|
|
66
|
+
def flush_data_when_buffer_is_full(self):
|
|
67
|
+
if len(self.cache_data[Const.DATA]) >= self.buffer_size:
|
|
68
|
+
self.write_data_json(self.dump_file_path)
|
|
69
|
+
|
|
70
|
+
def update_stack(self, new_data):
|
|
71
|
+
self.cache_stack.update(new_data)
|
|
72
|
+
|
|
73
|
+
def update_construct(self, new_data):
|
|
74
|
+
self.cache_construct.update(new_data)
|
|
75
|
+
|
|
76
|
+
def write_data_json(self, file_path):
|
|
77
|
+
logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
|
|
78
|
+
if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
|
|
79
|
+
data_to_write = load_json(file_path)
|
|
80
|
+
else:
|
|
81
|
+
self.init_json['data_path'] = self.dump_tensor_data_dir
|
|
82
|
+
data_to_write = self.init_json
|
|
83
|
+
data_to_write[Const.DATA].update(self.cache_data[Const.DATA])
|
|
84
|
+
save_json(file_path, data_to_write, indent=1)
|
|
85
|
+
self.cache_data[Const.DATA].clear()
|
|
86
|
+
|
|
87
|
+
def write_stack_info_json(self, file_path):
|
|
88
|
+
save_json(file_path, self.cache_stack, indent=1)
|
|
89
|
+
|
|
90
|
+
def write_construct_info_json(self, file_path):
|
|
91
|
+
save_json(file_path, self.cache_construct, indent=1)
|
|
92
|
+
|
|
93
|
+
def write_json(self):
|
|
94
|
+
self.write_data_json(self.dump_file_path)
|
|
95
|
+
self.write_stack_info_json(self.stack_file_path)
|
|
96
|
+
self.write_construct_info_json(self.construct_file_path)
|
msprobe/core/data_dump/scope.py
CHANGED
|
@@ -1,178 +1,178 @@
|
|
|
1
|
-
from abc import ABC, abstractmethod
|
|
2
|
-
from msprobe.core.common.exceptions import ScopeException
|
|
3
|
-
from msprobe.core.common.const import Const
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
def build_scope(scope_class, scope=None, api_list=None):
|
|
7
|
-
if not scope and not api_list:
|
|
8
|
-
return None
|
|
9
|
-
if scope is None:
|
|
10
|
-
scope = []
|
|
11
|
-
if api_list is None:
|
|
12
|
-
api_list = []
|
|
13
|
-
if scope_class:
|
|
14
|
-
return scope_class(scope, api_list)
|
|
15
|
-
return build_range_scope_according_to_scope_name(scope, api_list)
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
def build_range_scope_according_to_scope_name(scope, api_list):
|
|
19
|
-
api_range_scope = APIRangeScope(scope, api_list)
|
|
20
|
-
module_range_scope = ModuleRangeScope(scope, api_list)
|
|
21
|
-
if not scope: # 如果没有scope参数则用哪类scope都一样
|
|
22
|
-
return api_range_scope
|
|
23
|
-
if api_range_scope.is_valid and module_range_scope.is_valid:
|
|
24
|
-
raise ScopeException(ScopeException.InvalidScope, f"scope={scope}.")
|
|
25
|
-
elif api_range_scope.is_valid:
|
|
26
|
-
return api_range_scope
|
|
27
|
-
elif module_range_scope.is_valid:
|
|
28
|
-
return module_range_scope
|
|
29
|
-
else:
|
|
30
|
-
raise ScopeException(ScopeException.InvalidScope, f"scope={scope}")
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
class BaseScope(ABC):
|
|
34
|
-
Module_Type_Module = "Module"
|
|
35
|
-
Module_Type_API = "api"
|
|
36
|
-
|
|
37
|
-
def __init__(self, scope, api_list):
|
|
38
|
-
scope, api_list = self.rectify_args(scope, api_list)
|
|
39
|
-
self.scope = scope
|
|
40
|
-
self.api_list = api_list
|
|
41
|
-
|
|
42
|
-
@staticmethod
|
|
43
|
-
def rectify_args(scope, api_list):
|
|
44
|
-
if not isinstance(api_list, list):
|
|
45
|
-
raise ScopeException(ScopeException.InvalidApiStr,
|
|
46
|
-
f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
|
|
47
|
-
for api in api_list:
|
|
48
|
-
if not isinstance(api, str):
|
|
49
|
-
raise ScopeException(ScopeException.InvalidApiStr,
|
|
50
|
-
f"api_list中的元素须配置为字符串,实际类型为{type(api)}.")
|
|
51
|
-
if isinstance(scope, str):
|
|
52
|
-
scope = [scope]
|
|
53
|
-
return scope, api_list
|
|
54
|
-
if not isinstance(scope, list):
|
|
55
|
-
raise ScopeException(ScopeException.InvalidScope,
|
|
56
|
-
f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.")
|
|
57
|
-
for s in scope:
|
|
58
|
-
if not isinstance(s, str):
|
|
59
|
-
raise ScopeException(ScopeException.InvalidScope,
|
|
60
|
-
f"scope列表元素要求类型为字符串,实际类型为{type(s)}.")
|
|
61
|
-
return scope, api_list
|
|
62
|
-
|
|
63
|
-
@abstractmethod
|
|
64
|
-
def check(self, name):
|
|
65
|
-
pass
|
|
66
|
-
|
|
67
|
-
def check_api_list(self, api_name):
|
|
68
|
-
if not self.api_list:
|
|
69
|
-
return True
|
|
70
|
-
for api_str in self.api_list:
|
|
71
|
-
if api_str in api_name:
|
|
72
|
-
return True
|
|
73
|
-
return False
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
class ListScope(BaseScope):
|
|
77
|
-
@staticmethod
|
|
78
|
-
def rectify_args(scope, api_list):
|
|
79
|
-
if scope and api_list:
|
|
80
|
-
raise ScopeException(ScopeException.ArgConflict,
|
|
81
|
-
f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
|
|
82
|
-
return super(ListScope, ListScope).rectify_args(scope, api_list)
|
|
83
|
-
|
|
84
|
-
def check(self, module_name):
|
|
85
|
-
if not self.scope or module_name in self.scope:
|
|
86
|
-
return self.check_api_list(module_name)
|
|
87
|
-
return False
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
class RangeScope(BaseScope, ABC):
|
|
91
|
-
|
|
92
|
-
def __init__(self, *args):
|
|
93
|
-
super().__init__(*args)
|
|
94
|
-
self.in_scope = False
|
|
95
|
-
self.is_valid = self.check_scope_is_valid()
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
@staticmethod
|
|
99
|
-
def rectify_args(scope, api_list):
|
|
100
|
-
scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
|
|
101
|
-
if isinstance(scope, list):
|
|
102
|
-
if len(scope) == 1:
|
|
103
|
-
scope.append(scope[0])
|
|
104
|
-
elif len(scope) > 2:
|
|
105
|
-
raise ScopeException(ScopeException.InvalidScope,
|
|
106
|
-
f"scope参数指定区间断点,须传入长度为1或2的列表,实际长度为{len(scope)}.")
|
|
107
|
-
|
|
108
|
-
return scope, api_list
|
|
109
|
-
|
|
110
|
-
@abstractmethod
|
|
111
|
-
def check_scope_is_valid(self):
|
|
112
|
-
pass
|
|
113
|
-
|
|
114
|
-
def begin_module(self, module_name):
|
|
115
|
-
pass
|
|
116
|
-
|
|
117
|
-
def end_module(self, module_name):
|
|
118
|
-
pass
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
class APIRangeScope(RangeScope):
|
|
122
|
-
def check_scope_is_valid(self):
|
|
123
|
-
if not self.scope:
|
|
124
|
-
return True
|
|
125
|
-
scope_start_type = self.scope[0].split(Const.SEP)[0]
|
|
126
|
-
if scope_start_type == BaseScope.Module_Type_Module:
|
|
127
|
-
return False
|
|
128
|
-
scope_stop_type = self.scope[1].split(Const.SEP)[0]
|
|
129
|
-
if scope_stop_type == BaseScope.Module_Type_Module:
|
|
130
|
-
return False
|
|
131
|
-
return True
|
|
132
|
-
|
|
133
|
-
def check(self, api_name):
|
|
134
|
-
if self.scope and api_name == self.scope[0]:
|
|
135
|
-
self.in_scope = True
|
|
136
|
-
|
|
137
|
-
if not self.scope or self.in_scope:
|
|
138
|
-
result = self.check_api_list(api_name)
|
|
139
|
-
else:
|
|
140
|
-
result = False
|
|
141
|
-
|
|
142
|
-
if self.scope and api_name == self.scope[1]:
|
|
143
|
-
self.in_scope = False
|
|
144
|
-
return result
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
class ModuleRangeScope(RangeScope):
    """Range scope whose start and stop markers are module names.

    Unlike APIs, modules contain sub-structures that also need to be dumped,
    so the start and end of a module are controlled precisely with pre_hook
    and full_backward_hook; begin_module and end_module are called when those
    hooks fire to open and close the range.
    """

    def check_scope_is_valid(self):
        """Return True when no scope is set or both endpoints are module names."""
        if not self.scope:
            return True
        start_type = self.scope[0].split(Const.SEP)[0]
        stop_type = self.scope[1].split(Const.SEP)[0]
        return start_type == BaseScope.Module_Type_Module and stop_type == BaseScope.Module_Type_Module

    def begin_module(self, module_name):
        """Open the range when the configured start module begins."""
        if self.scope and module_name == self.scope[0]:
            self.in_scope = True

    def end_module(self, module_name):
        """Close the range when the configured stop module ends."""
        if self.scope and module_name == self.scope[1]:
            self.in_scope = False

    def check(self, module_name):
        """Return whether ``module_name`` is selected given the current range state."""
        if self.scope and not self.in_scope:
            return False
        return self.check_api_list(module_name)
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from msprobe.core.common.exceptions import ScopeException
|
|
3
|
+
from msprobe.core.common.const import Const
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def build_scope(scope_class, scope=None, api_list=None):
    """Create a scope object, or return None when nothing restricts the dump.

    Args:
        scope_class: class to instantiate as ``scope_class(scope, api_list)``;
            when falsy, the range-scope type is inferred from ``scope``.
        scope: scope markers (string or list); ``None`` means empty.
        api_list: list of API substrings to filter by; ``None`` means empty.

    Returns:
        The constructed scope object, or None if both ``scope`` and
        ``api_list`` are empty.
    """
    if not (scope or api_list):
        return None
    scope = [] if scope is None else scope
    api_list = [] if api_list is None else api_list
    if not scope_class:
        return build_range_scope_according_to_scope_name(scope, api_list)
    return scope_class(scope, api_list)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def build_range_scope_according_to_scope_name(scope, api_list):
    """Pick APIRangeScope or ModuleRangeScope based on the scope endpoints.

    Both candidates are constructed; exactly one must report ``is_valid``.
    With an empty scope either type behaves the same, so the API variant is
    returned.

    Raises:
        ScopeException: if neither or both candidates accept ``scope``.
    """
    candidates = (APIRangeScope(scope, api_list), ModuleRangeScope(scope, api_list))
    if not scope:
        return candidates[0]
    accepted = [candidate for candidate in candidates if candidate.is_valid]
    if len(accepted) == 2:
        raise ScopeException(ScopeException.InvalidScope, f"scope={scope}.")
    if len(accepted) == 1:
        return accepted[0]
    raise ScopeException(ScopeException.InvalidScope, f"scope={scope}")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class BaseScope(ABC):
    """Abstract base for scopes that filter names by ``scope`` and ``api_list``.

    Subclasses implement :meth:`check` to decide whether a name is selected;
    :meth:`check_api_list` provides the shared substring filter.
    """

    Module_Type_Module = "Module"
    Module_Type_API = "api"

    def __init__(self, scope, api_list):
        # rectify_args is a staticmethod that subclasses override; calling it
        # through ``self`` dispatches to the most-derived version.
        self.scope, self.api_list = self.rectify_args(scope, api_list)

    @staticmethod
    def rectify_args(scope, api_list):
        """Validate the arguments and normalise a string ``scope`` to a list.

        Returns:
            tuple: ``(scope, api_list)``.

        Raises:
            ScopeException: if ``api_list`` is not a list of strings, or
                ``scope`` is neither a string nor a list of strings.
        """
        if not isinstance(api_list, list):
            raise ScopeException(ScopeException.InvalidApiStr,
                                 f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
        for entry in api_list:
            if isinstance(entry, str):
                continue
            raise ScopeException(ScopeException.InvalidApiStr,
                                 f"api_list中的元素须配置为字符串,实际类型为{type(entry)}.")

        if isinstance(scope, str):
            return [scope], api_list

        if not isinstance(scope, list):
            raise ScopeException(ScopeException.InvalidScope,
                                 f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.")
        for entry in scope:
            if isinstance(entry, str):
                continue
            raise ScopeException(ScopeException.InvalidScope,
                                 f"scope列表元素要求类型为字符串,实际类型为{type(entry)}.")
        return scope, api_list

    @abstractmethod
    def check(self, name):
        """Return whether ``name`` is selected by this scope."""

    def check_api_list(self, api_name):
        """Return True if ``api_list`` is empty or any entry is a substring of ``api_name``."""
        return not self.api_list or any(api_str in api_name for api_str in self.api_list)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class ListScope(BaseScope):
    """Scope that selects an explicit list of names, or filters by ``api_list``.

    ``scope`` and ``api_list`` are mutually exclusive.
    """

    @staticmethod
    def rectify_args(scope, api_list):
        """Reject configuring both arguments, then apply the base validation.

        Raises:
            ScopeException: if both ``scope`` and ``api_list`` are non-empty,
                or the base validation fails.
        """
        both_configured = scope and api_list
        if both_configured:
            raise ScopeException(ScopeException.ArgConflict,
                                 f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
        return super(ListScope, ListScope).rectify_args(scope, api_list)

    def check(self, module_name):
        """Return whether ``module_name`` is listed in scope (or scope is empty) and passes api_list."""
        if self.scope and module_name not in self.scope:
            return False
        return self.check_api_list(module_name)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class RangeScope(BaseScope, ABC):
    """Base class for scopes that select a contiguous range between two markers.

    After construction, ``self.scope`` is either empty (no range restriction)
    or a two-element ``[start, stop]`` list. ``self.in_scope`` holds the
    open/closed range state that subclasses toggle. ``self.is_valid`` caches
    :meth:`check_scope_is_valid` for the concrete subclass.
    """

    def __init__(self, *args):
        super().__init__(*args)
        self.in_scope = False
        self.is_valid = self.check_scope_is_valid()

    @staticmethod
    def rectify_args(scope, api_list):
        """Run the base validation, then normalise ``scope`` to ``[start, stop]``.

        A one-element scope means start == stop and is expanded to ``[x, x]``.
        Scopes with more than two elements are rejected.

        Returns:
            tuple: ``(scope, api_list)``. When ``scope`` is expanded, a new
            list is returned, so the caller's list is never modified.

        Raises:
            ScopeException: if ``scope`` has more than two elements, or if the
                base-class validation fails.
        """
        scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
        if isinstance(scope, list):
            if len(scope) == 1:
                # Build a new list instead of appending in place. An in-place
                # append mutated the caller's argument; e.g.
                # build_range_scope_according_to_scope_name passes the same
                # list object to both APIRangeScope and ModuleRangeScope.
                scope = [scope[0], scope[0]]
            elif len(scope) > 2:
                raise ScopeException(ScopeException.InvalidScope,
                                     f"scope参数指定区间断点,须传入长度为1或2的列表,实际长度为{len(scope)}.")

        return scope, api_list

    @abstractmethod
    def check_scope_is_valid(self):
        """Return whether ``self.scope`` is acceptable for this range type."""

    def begin_module(self, module_name):
        """Module-start hook; a no-op unless a subclass overrides it."""

    def end_module(self, module_name):
        """Module-end hook; a no-op unless a subclass overrides it."""
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class APIRangeScope(RangeScope):
    """Range scope whose start and stop markers are API names.

    The range opens when the start API is checked and closes right after the
    stop API is checked, so both endpoints are included.
    """

    def check_scope_is_valid(self):
        """Return True when no scope is set or neither endpoint is a module name."""
        if not self.scope:
            return True
        module_tag = BaseScope.Module_Type_Module
        if self.scope[0].split(Const.SEP)[0] == module_tag:
            return False
        return self.scope[1].split(Const.SEP)[0] != module_tag

    def check(self, api_name):
        """Return whether ``api_name`` is selected, updating the in-range state."""
        has_scope = bool(self.scope)
        if has_scope and api_name == self.scope[0]:
            self.in_scope = True

        if not has_scope or self.in_scope:
            selected = self.check_api_list(api_name)
        else:
            selected = False

        # Close the range only after the stop API itself has been evaluated.
        if has_scope and api_name == self.scope[1]:
            self.in_scope = False
        return selected
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class ModuleRangeScope(RangeScope):
    """Range scope whose start and stop markers are module names.

    Unlike APIs, modules contain sub-structures that also need to be dumped,
    so the start and end of a module are controlled precisely with pre_hook
    and full_backward_hook; begin_module and end_module are called when those
    hooks fire to open and close the range.
    """

    def check_scope_is_valid(self):
        """Return True when no scope is set or both endpoints are module names."""
        if not self.scope:
            return True
        start_type = self.scope[0].split(Const.SEP)[0]
        stop_type = self.scope[1].split(Const.SEP)[0]
        return start_type == BaseScope.Module_Type_Module and stop_type == BaseScope.Module_Type_Module

    def begin_module(self, module_name):
        """Open the range when the configured start module begins."""
        if self.scope and module_name == self.scope[0]:
            self.in_scope = True

    def end_module(self, module_name):
        """Close the range when the configured stop module ends."""
        if self.scope and module_name == self.scope[1]:
            self.in_scope = False

    def check(self, module_name):
        """Return whether ``module_name`` is selected given the current range state."""
        if self.scope and not self.in_scope:
            return False
        return self.check_api_list(module_name)
|
|
File without changes
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
|
|
2
|
+
class GradConst:
    """Constants for the gradient tooling.

    Covers supported frameworks, grad file suffixes, callback keys, levels,
    file-safety limits and CSV header entries.
    """

    PYTORCH = "PyTorch"
    MindSpore = "MindSpore"
    # Derived from the individual constants so the set cannot drift from them.
    FRAMEWORKS = {PYTORCH, MindSpore}

    NPY_SUFFIX = "npy"
    PT_SUFFIX = "pt"
    GRAD_FILE_SUFFIX = {NPY_SUFFIX, PT_SUFFIX}

    # for callback
    CURRENT_STEP = "current_step"

    PARAM_LIST = "param_list"
    RANK = "rank"
    STEP = "step"
    BOUNDS = "bounds"
    OUTPUT_PATH = "output_path"

    # level const
    LEVEL = "level"
    LEVEL0 = "L0"
    LEVEL1 = "L1"
    LEVEL2 = "L2"
    SUPPORTED_LEVEL = {LEVEL0, LEVEL1, LEVEL2}

    # numpy coding
    # NOTE(review): index meanings depend on the encoding used by callers — confirm there.
    STEP_IDX = 0
    SHAPE_DIM_IDX = 4
    # 10 * 1024**3; presumably a size limit in bytes — confirm against callers.
    MAX_SIZE = 10 * 1024 * 1024 * 1024

    # direction suffix
    DIR_SUFFIX = "dir.npy"

    # file safety: octal permission modes and length limits for paths/names
    DATA_DIR_AUTHORITY = 0o750
    DATA_FILE_AUTHORITY = 0o640
    DIRECTORY_LENGTH = 4096
    FILE_NAME_LENGTH = 255
    FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
    PARAM_VALID_PATTERN = r"^[a-zA-Z0-9_.]+$"
    DIR = "dir"
    FILE = "file"

    STEP_FINISH = "step_finish"

    SUMMARY = "summary"

    # csv header entry
    MD5 = "MD5"
    DISTRIBUTION = "distribution"
    SHAPE = "shape"
    MAX = "max"
    MIN = "min"
    NORM = "norm"
|
|
57
|
+
|
|
58
|
+
# Per-level settings: the CSV "header" columns written for that level and the
# "have_grad_direction" flag (True for L1/L2, False for L0).
level_adp = {
    "L0": dict(
        header=[GradConst.MD5, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
        have_grad_direction=False,
    ),
    "L1": dict(
        header=[GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
        have_grad_direction=True,
    ),
    "L2": dict(
        header=[GradConst.DISTRIBUTION, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
        have_grad_direction=True,
    ),
}
|