mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
- mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +101 -182
- msprobe/__init__.py +1 -0
- msprobe/{config/config.json → config.json} +49 -27
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +124 -124
- msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
- msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +341 -241
- msprobe/core/common/exceptions.py +100 -88
- msprobe/core/common/{file_check.py → file_utils.py} +478 -265
- msprobe/core/common/log.py +76 -55
- msprobe/core/common/utils.py +385 -516
- msprobe/core/common_config.py +85 -58
- msprobe/core/compare/acc_compare.py +300 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +223 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
- msprobe/core/compare/utils.py +430 -0
- msprobe/core/data_dump/data_collector.py +154 -140
- msprobe/core/data_dump/data_processor/base.py +314 -245
- msprobe/core/data_dump/data_processor/factory.py +59 -61
- msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
- msprobe/core/data_dump/json_writer.py +96 -116
- msprobe/core/data_dump/scope.py +178 -178
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +171 -0
- msprobe/core/grad_probe/utils.py +64 -0
- msprobe/docs/01.installation.md +89 -0
- msprobe/docs/02.config_introduction.md +165 -0
- msprobe/docs/03.config_examples.md +247 -0
- msprobe/docs/04.acl_config_examples.md +76 -0
- msprobe/docs/05.data_dump_PyTorch.md +198 -0
- msprobe/docs/06.data_dump_MindSpore.md +243 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
- msprobe/docs/17.grad_probe.md +207 -0
- msprobe/docs/FAQ_PyTorch.md +177 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/grad_probe_image-1.png +0 -0
- msprobe/docs/img/grad_probe_image-2.png +0 -0
- msprobe/docs/img/grad_probe_image-3.png +0 -0
- msprobe/docs/img/grad_probe_image-4.png +0 -0
- msprobe/docs/img/grad_probe_image.png +0 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
- msprobe/mindspore/api_accuracy_checker/main.py +9 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +106 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +81 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +219 -0
- msprobe/mindspore/compare/ms_graph_compare.py +348 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +66 -51
- msprobe/mindspore/debugger/precision_debugger.py +126 -32
- msprobe/mindspore/dump/dump_tool_factory.py +35 -38
- msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
- msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
- msprobe/mindspore/dump/jit_dump.py +72 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +90 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +94 -0
- msprobe/mindspore/grad_probe/utils.py +30 -0
- msprobe/mindspore/ms_config.py +128 -78
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +378 -0
- msprobe/mindspore/task_handler_factory.py +24 -21
- msprobe/msprobe.py +105 -67
- msprobe/pytorch/__init__.py +4 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
- msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/__init__.py +2 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +20 -31
- msprobe/pytorch/common/parse_json.py +39 -37
- msprobe/pytorch/common/utils.py +305 -224
- msprobe/pytorch/compare/distributed_compare.py +66 -111
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +34 -36
- msprobe/pytorch/compare/pt_compare.py +50 -0
- msprobe/pytorch/debugger/debugger_config.py +95 -86
- msprobe/pytorch/debugger/precision_debugger.py +125 -95
- msprobe/pytorch/free_benchmark/__init__.py +8 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -67
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +37 -37
- msprobe/pytorch/free_benchmark/common/params.py +129 -129
- msprobe/pytorch/free_benchmark/common/utils.py +102 -98
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
- msprobe/pytorch/free_benchmark/main.py +105 -102
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
- msprobe/pytorch/function_factory.py +76 -0
- msprobe/pytorch/functional/dump_module.py +39 -39
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -161
- msprobe/pytorch/hook_module/hook_module.py +120 -109
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
- msprobe/pytorch/hook_module/utils.py +30 -29
- msprobe/pytorch/hook_module/wrap_aten.py +110 -100
- msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
- msprobe/pytorch/hook_module/wrap_functional.py +105 -108
- msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
- msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
- msprobe/pytorch/hook_module/wrap_torch.py +86 -88
- msprobe/pytorch/hook_module/wrap_vf.py +62 -64
- msprobe/pytorch/module_processer.py +138 -98
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +236 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -273
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
- msprobe/pytorch/online_dispatch/utils.py +130 -187
- msprobe/pytorch/parse.py +4 -4
- msprobe/pytorch/parse_tool/cli.py +32 -32
- msprobe/pytorch/parse_tool/lib/compare.py +260 -259
- msprobe/pytorch/parse_tool/lib/config.py +52 -51
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
- msprobe/pytorch/parse_tool/lib/utils.py +316 -367
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
- msprobe/pytorch/pt_config.py +188 -93
- msprobe/pytorch/service.py +246 -167
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/config/README.md +0 -397
- msprobe/mindspore/doc/dump.md +0 -65
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
- msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +0 -182
- msprobe/pytorch/doc/dump.md +0 -207
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc/在线精度比对.md +0 -90
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
|
@@ -1,7 +1,70 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
|
|
4
|
+
from msprobe.core.common.const import FileCheckConst
|
|
5
|
+
from msprobe.core.common.file_utils import FileChecker
|
|
6
|
+
from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
|
|
7
|
+
from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
|
|
8
|
+
from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
|
|
9
|
+
from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
|
|
10
|
+
from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
|
|
11
|
+
|
|
12
|
+
# API names grouped under the "hf_32_standard_api" label.
# NOTE(review): presumably the ops judged with an HF32-specific standard; the consumer is not visible here — confirm.
hf_32_standard_api = ["conv1d", "conv2d"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Backward_Message:
    """Fixed message texts describing why a backward pass is rejected or skipped."""

    # Backward was requested more than once.
    MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."

    # The call used an ``out=...`` argument, which autograd cannot differentiate.
    UNSUPPORT_BACKWARD_MESSAGE = (
        "function with out=... arguments don't support automatic differentiation, "
        "skip backward."
    )

    # The backward computation produced no result.
    NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class UtDataInfo:
    """Plain container bundling the values produced by one unit-test run of an API.

    Every constructor argument is stored unchanged on the attribute of the
    same name; no validation or conversion is performed.
    """

    def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list,
                 backward_message, rank=0):
        # Forward outputs (benchmark side and device side).
        self.bench_output, self.device_output = bench_output, device_output
        # Backward results (benchmark side and device side) and the gradient fed in.
        self.bench_grad, self.device_grad = bench_grad, device_grad
        self.grad_in = grad_in
        # Forward input data and the backward status text.
        self.in_fwd_data_list = in_fwd_data_list
        self.backward_message = backward_message
        # Rank the data belongs to; defaults to 0.
        self.rank = rank
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_validated_result_csv_path(result_csv_path, mode):
    """Validate an existing accuracy-checking csv path and return the checked path.

    Args:
        result_csv_path: Path of the csv file to validate.
        mode: Either ``'result'`` or ``'detail'``. In ``'result'`` mode the file
            name must additionally look like
            ``accuracy_checking_result_<14 digits>.csv``.

    Returns:
        The path returned by ``FileChecker.common_check()``.

    Raises:
        ValueError: If ``mode`` is not one of the two accepted values, or if in
            ``'result'`` mode the file name does not follow the expected pattern.
    """
    if mode not in ('result', 'detail'):
        raise ValueError("The csv mode must be result or detail")

    result_csv_path_checker = FileChecker(result_csv_path, FileCheckConst.FILE,
                                          ability=FileCheckConst.READ_WRITE_ABLE,
                                          file_type=FileCheckConst.CSV_SUFFIX)
    validated_result_csv_path = result_csv_path_checker.common_check()

    if mode == 'result':
        result_csv_name = os.path.basename(validated_result_csv_path)
        pattern = r"^accuracy_checking_result_\d{14}\.csv$"
        # fullmatch (not match): with match, "$" would also accept a name that
        # carries a trailing newline after ".csv".
        if not re.fullmatch(pattern, result_csv_name):
            raise ValueError("When continue run ut, please do not modify the result csv name.")
    return validated_result_csv_path
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_validated_details_csv_path(validated_result_csv_path):
    """Derive the details-csv path that sits next to a result csv and validate it.

    The file name of ``validated_result_csv_path`` has every occurrence of
    ``'result'`` replaced by ``'details'``; the directory is kept. The derived
    path is then run through ``FileChecker.common_check()`` and that checked
    path is returned.
    """
    directory, result_name = os.path.split(validated_result_csv_path)
    candidate = os.path.join(directory, result_name.replace('result', 'details'))
    checker = FileChecker(
        candidate,
        FileCheckConst.FILE,
        ability=FileCheckConst.READ_WRITE_ABLE,
        file_type=FileCheckConst.CSV_SUFFIX,
    )
    return checker.common_check()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def exec_api(api_type, api_name, device, args, kwargs):
    """Build the op template matching ``api_type`` and run its forward.

    Args:
        api_type: One of ``"Functional"``, ``"Tensor"``, ``"Torch"``, ``"Aten"``
            or ``"NPU"``; selects which ``*OPTemplate`` class is instantiated.
        api_name: Name of the API, passed to the template constructor.
        device: Only forwarded to ``NpuOPTemplate``; ignored for other types.
        args: Positional arguments for the API call.
        kwargs: Keyword arguments for the API call.

    Returns:
        Whatever the template's ``forward(*args, **kwargs)`` returns.

    Raises:
        ValueError: If ``api_type`` is not one of the supported values
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if api_type == "Functional":
        torch_api = FunctionalOPTemplate(api_name, str, False)
    elif api_type == "Tensor":
        torch_api = TensorOPTemplate(api_name, str, False)
    elif api_type == "Torch":
        torch_api = TorchOPTemplate(api_name, str, False)
    elif api_type == "Aten":
        torch_api = AtenOPTemplate(api_name, None, False)
    elif api_type == "NPU":
        torch_api = NpuOPTemplate(api_name, None, False, device)
    else:
        raise ValueError(f"Unsupported api_type {api_type!r} for api {api_name!r}.")
    return torch_api.forward(*args, **kwargs)
|
|
File without changes
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
import glob
|
|
2
|
+
import os.path
|
|
3
|
+
import time
|
|
4
|
+
import re
|
|
5
|
+
from multiprocessing import Queue
|
|
6
|
+
from typing import Optional, Union, Dict, Any
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
|
|
12
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
|
|
13
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
|
|
14
|
+
from msprobe.pytorch.common.utils import logger
|
|
15
|
+
from msprobe.core.common.file_utils import remove_path
|
|
16
|
+
from msprobe.pytorch.common.utils import save_api_data, load_api_data, save_pt, load_pt
|
|
17
|
+
|
|
18
|
+
# Payload types used in this module's transfer signatures: an ApiData record, a plain dict, or a str.
# NOTE(review): the trailing comment does not match the alias and looks like a leftover — confirm and remove.
BufferType = Union[ApiData, Dict[str, Any], str]  # Union[Tensor, Tuple[Optional[Tensor]]]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class ATTLConfig:
    """Settings for one ATTL session.

    ``nfs_path`` and ``tls_path`` default to ``None`` (not configured), so they
    are annotated as optional.
    """

    # True when this side is the benchmark device.
    is_benchmark_device: bool
    # Address and port used for the TCP connection.
    connect_ip: str
    connect_port: int
    # storage_config
    nfs_path: Optional[str] = None
    tls_path: Optional[str] = None
    check_sum: bool = True
    queue_size: int = 50
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ATTL:
    """Transport endpoint that moves API data either over TCP or through a shared directory.

    Depending on the session config, the instance works on a directory
    (``nfs_path`` set: ``upload`` / ``download``), starts a ``TCPServer``
    (benchmark device: ``recv``), or starts a ``TCPClient`` (``need_dump``:
    ``send``).
    """

    def __init__(self, session_id: str, session_config: ATTLConfig, need_dump=True) -> None:
        """Validate the config and start the matching socket manager, if any.

        Args:
            session_id: Identifier stored on the instance.
            session_config: Connection / storage settings.
            need_dump: When neither ``nfs_path`` nor ``is_benchmark_device`` is
                set, a ``TCPClient`` is started only if this is true.

        Raises:
            Exception: Propagated from ``check_attl_config`` on invalid settings.
        """
        self.session_id = session_id
        self.session_config = session_config
        self.logger = logger
        self.socket_manager = None
        # NOTE(review): capacity is fixed at 50; session_config.queue_size is not consulted here.
        self.data_queue = Queue(maxsize=50)
        self.dequeue_list = []
        self.message_end = False
        self.kill_progress = False
        self.check_attl_config()
        if self.session_config.nfs_path:
            # Directory mode: no socket manager is created.
            self.nfs_path = self.session_config.nfs_path
        elif self.session_config.is_benchmark_device:
            self.socket_manager = TCPServer(self.session_config.connect_port,
                                            self.data_queue,
                                            self.session_config.check_sum,
                                            self.session_config.tls_path)
            self.socket_manager.start()
        elif need_dump:
            self.socket_manager = TCPClient(self.session_config.connect_ip,
                                            self.session_config.connect_port,
                                            self.session_config.check_sum,
                                            self.session_config.tls_path)
            self.socket_manager.start()

    def check_attl_config(self):
        """Validate the session config.

        When ``nfs_path`` is set only its existence is checked; otherwise the
        IPv4 address and the port range are checked.

        Raises:
            Exception: If the nfs path does not exist, the address is not a
                dotted IPv4 address, or the port is outside 1..65535.
        """
        if self.session_config.nfs_path:
            if os.path.exists(self.session_config.nfs_path):
                return
            raise Exception(f"nfs path {self.session_config.nfs_path} doesn't exists.")
        # Raw string: "\d" and "\." are regex escapes, not Python string escapes.
        ipv4_pattern = r"([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
        # fullmatch so that an address followed by a trailing newline is rejected.
        if not re.fullmatch(ipv4_pattern, self.session_config.connect_ip):
            raise Exception(f"host {self.session_config.connect_ip} is invalid.")
        if not (0 < self.session_config.connect_port <= 65535):
            raise Exception(f"port {self.session_config.connect_port} is invalid.")

    def stop_serve(self):
        """Stop the socket manager when it is a ``TCPServer``; otherwise do nothing."""
        if isinstance(self.socket_manager, TCPServer):
            self.socket_manager.stop()

    def send(self, buffer: BufferType) -> None:
        """
        npu major in 'send' (client)

        Serialises ``buffer`` with ``save_api_data`` and hands the bytes to the
        socket manager's sending queue. If serialisation fails, the failure is
        logged and the buffer is skipped.
        """
        # know receiver receive and go next
        if isinstance(buffer, ApiData):
            buffer = move2target_device(buffer, torch.device('cpu'))

        # NOTE(review): buffer.kwargs / buffer.name are accessed for every buffer type,
        # so in practice an ApiData-like object is expected here — confirm.
        buffer.kwargs.pop('device', None)
        rank = buffer.rank if hasattr(buffer, "rank") and buffer.rank is not None else 0
        step = buffer.step if hasattr(buffer, "step") else 0
        try:
            io_buff = save_api_data(buffer)
        except Exception as e:
            self.logger.info(f"{buffer.name} can not be saved, skip: {e}")
            return
        data = io_buff.getvalue()
        self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)

    def recv(self, timeout_ms=0) -> Optional[BufferType]:
        """Take the next item from the data queue and decode it.

        Args:
            timeout_ms: When positive, sleep that long, poll the queue once and
                give up if it is empty. When 0, poll every 0.1 s until an item
                arrives or the end-of-messages state is reached.

        Returns:
            ``"STOP_"`` for a ``b"STOP_"`` or ``b"KILL_"`` item, ``"KILL_"`` once
            the queue is drained after ``b"KILL_"``, ``None`` on timeout or when
            the item could not be decoded, otherwise the object returned by
            ``load_api_data``.
        """
        buffer = None
        while True:
            if timeout_ms > 0:
                time.sleep(timeout_ms / 1000.0)
            if not self.data_queue.empty():
                buffer = self.data_queue.get()
                break
            if timeout_ms > 0:  # timeout is the only case we give up and return None
                break
            if self.message_end and self.data_queue.empty():
                buffer = b"KILL_CONFIRM"
                self.kill_progress = True
                break
            time.sleep(0.1)  # waiting outside the lock before next attempt

        if buffer is None:
            # this is a result of a timeout
            self.logger.info("RECEIVE API DATA TIMED OUT")
            return None

        if buffer == b"STOP_":
            return "STOP_"
        if buffer == b"KILL_":
            # Remember that the sender finished; the remaining queue is drained first.
            self.message_end = True
            return "STOP_"
        if buffer == b"KILL_CONFIRM":
            self.kill_progress = True
            return "KILL_"
        try:
            buffer = load_api_data(buffer)
        except Exception as e:
            self.logger.warning("there is something error. please check it. %s", e)
        if isinstance(buffer, bytes):
            # Still raw bytes: decoding failed, so there is nothing to hand out.
            return None
        return buffer

    def upload(self, buffer: BufferType):
        """Save ``buffer`` with ``save_pt`` into the configured nfs directory.

        ``ApiData`` is moved to CPU and stored as ``<name>.pt``; any other
        buffer is stored under ``<buffer>_<unix seconds>``. Save failures are
        logged as warnings and not raised.
        """
        if isinstance(buffer, ApiData):
            buffer = move2target_device(buffer, torch.device('cpu'))
            file_path = os.path.join(self.session_config.nfs_path, buffer.name + ".pt")
        else:
            file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")

        try:
            save_pt(buffer, file_path)
        except Exception as e:
            self.logger.warning("there is something error in save_pt. please check it. %s", e)

    def download(self):
        """Load and remove one file from the nfs directory.

        Files are looked up in the order ``start*``, ``*.pt``, ``end*``; the
        first glob hit is loaded with ``load_pt`` and then removed, even when
        loading failed.

        Returns:
            The loaded object, or ``None`` when no file matched or loading failed.
        """
        buffer = None
        cur_file = None
        for file_type in ("start*", "*.pt", "end*"):
            files = glob.glob(os.path.join(self.nfs_path, file_type))
            if files:
                cur_file = files[0]
                break

        if cur_file is not None:
            try:
                buffer = load_pt(cur_file)
            except Exception as e:
                self.logger.warning("there is something error. please check it. %s", e)
            remove_path(cur_file)
        return buffer
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def move2device_exec(obj, device):
    """Recursively detach tensors in ``obj`` and move them to ``device``.

    Args:
        obj: A tensor, a ``torch.device``, a (possibly nested) list / tuple /
            dict of those, or any other object.
        device: Target device, given as a ``torch.device`` or a device string.

    Returns:
        A structure of the same shape: lists stay lists, other sequences become
        tuples, dict keys are kept, tensors are detached and placed on
        ``device``, ``torch.device`` values are replaced by the target device,
        and anything else is returned unchanged.
    """
    if isinstance(obj, (tuple, list)):
        moved = [move2device_exec(item, device) for item in obj]
        return moved if isinstance(obj, list) else tuple(moved)
    if isinstance(obj, dict):
        return {key: move2device_exec(val, device) for key, val in obj.items()}
    if isinstance(obj, torch.Tensor):
        detached = obj.detach()
        # Normalise first: comparing ``device.type`` (a str) with a torch.device
        # is always unequal, which made the previous guard a no-op.
        target = torch.device(device)
        if detached.device != target:
            detached = detached.to(target)
        return detached
    if "return_types" in str(type(obj)):
        return move2device_exec(tuple(obj), device)
    if isinstance(obj, torch.device):
        return torch.device(device)
    return obj
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def move2target_device(buffer: ApiData, target_device):
    """Return a new ``ApiData`` whose args and kwargs live on ``target_device``.

    Args:
        buffer: Source record; it is not modified.
        target_device: ``torch.device`` or device string to move to.

    Returns:
        A new ``ApiData`` with the same name, step and rank. Args and kwargs
        are always moved. The result is moved only when the target is the CPU;
        for any other target the original ``buffer.result`` is passed through.
    """
    # handle args
    new_args = move2device_exec(buffer.args, target_device)

    # handle kwargs
    new_kwargs = move2device_exec(buffer.kwargs, target_device)

    # handle result: only the CPU case uses a moved copy, so skip the transfer otherwise
    if target_device == torch.device('cpu') or target_device == "cpu":
        results = move2device_exec(buffer.result, target_device)
    else:
        results = buffer.result

    return ApiData(buffer.name, tuple(new_args), new_kwargs, results, buffer.step, buffer.rank)
|
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import io
|
|
3
|
+
import struct
|
|
4
|
+
import time
|
|
5
|
+
import os
|
|
6
|
+
import signal
|
|
7
|
+
import sys
|
|
8
|
+
from queue import Queue
|
|
9
|
+
from threading import Thread
|
|
10
|
+
from typing import Union
|
|
11
|
+
|
|
12
|
+
from twisted.internet import reactor, protocol, endpoints
|
|
13
|
+
from twisted.protocols.basic import FileSender
|
|
14
|
+
|
|
15
|
+
from msprobe.pytorch.common.utils import logger
|
|
16
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.ssl_config import cipher_list
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TCPDataItem:
    """A payload queued for sending, together with its delivery bookkeeping."""

    def __init__(self, data,
                 sequence_number: int,
                 rank: int = 0,
                 step: int = 0):
        """Store the payload and its identifying fields.

        Args:
            data: Raw payload to transmit.
            sequence_number: Sequence number assigned to this payload.
            rank: Rank the payload belongs to.
            step: Step the payload belongs to.
        """
        # Payload and the fields that identify it.
        self.raw_data, self.sequence_number = data, sequence_number
        self.rank, self.step = rank, step
        # Delivery counters, all starting from zero.
        self.retry_times = self.pending_time = self.busy_time = 0
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TCPClient:
    """Queue payloads and ship them to a server through a ``ClientProtocol``.

    Two daemon threads are started by the constructor: one drains
    ``send_queue`` and writes items to the protocol, the other consumes the
    protocol's ``ack_queue`` and retires or re-queues items tracked in
    ``resend_dict``. ``resend_dict`` is keyed by the string
    ``"<sequence_number>_<rank>_<step>"``.
    """

    MAX_SENDING_QUEUE_SIZE = 20
    ACK_SUCCESS = b"OK___"
    ACK_ERROR = b"ERROR"
    ACK_BUSY = b"BUSY_"
    ACK_STOP = b"STOP_"
    ACK_STOP_CONFIRM = b"OVER_"
    ACK_KILL_PROCESS = b"KILL_"

    QUEUE_PENDING_TIME = 600  # seconds a put may block on a full send queue (10 minutes) before it is abandoned
    RESEND_RETRY_TIMES = 2  # maximum number of retransmissions per item
    RESEND_TIMER_TIME = 5  # ACK receive timeout timer (not referenced inside this class)
    RESEND_PENDING_TIME = 60  # an item pending for more than 1 minute in total is given up

    def __init__(self, host="localhost", port=8000, check_sum=False, tls_path=None):
        """Create the queues and protocol, then start both worker threads.

        Args:
            host: Server host name or address.
            port: Server port.
            check_sum: Forwarded to ``ClientProtocol``; enables the MD5 digest
                in every sent frame.
            tls_path: Directory holding ``client.key`` / ``client.crt``; when
                falsy, ``start`` uses a plain TCP endpoint.
        """
        self.send_queue = Queue(self.MAX_SENDING_QUEUE_SIZE)
        self.resend_dict = dict()
        self.host = host
        self.port = port
        self.tls_path = tls_path
        self.factory = None
        self.sequence_number = 0
        self.signal_exit = False
        self.tcp_manager = ClientProtocol(ack_queue_size=100,
                                          chunk_size=655360,
                                          check_sum=check_sum)
        # ``daemon=True`` replaces the deprecated ``setDaemon(True)`` and
        # matches how ``start`` creates the reactor thread.
        self.send_thread = Thread(target=self._sending_queue_data, daemon=True)
        self.send_thread.start()
        self.destroy_thread = Thread(target=self._destroy_queue_data, daemon=True)
        self.destroy_thread.start()

    @staticmethod
    def run_reactor():
        """Run the Twisted reactor without installing signal handlers."""
        reactor.run(installSignalHandlers=False)

    @staticmethod
    def _resend_key(sequence_number, rank, step):
        """Return the ``resend_dict`` key for the given identifiers."""
        return f"{sequence_number}_{rank}_{step}"

    def start(self):
        """Connect to the server and run the reactor in a daemon thread."""

        def conn_callback(cur_protocol):
            if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host:
                logger.debug(f"Process: {os.getpid()} connects to server successfully.")
            else:
                logger.warning(f"Process: {os.getpid()} fails to connect to server. ")
                raise ConnectionError(f"Failed to connect to {self.host}.")

        def conn_err_callback(failure):
            self.signal_exit = True
            time.sleep(1)
            reactor.stop()
            logger.error(f"Failed to connected {self.host} {self.port}. Reason is {failure.getErrorMessage()}")
            # NOTE(review): SIGKILL on the current process ends it immediately,
            # so the parent-kill line below is presumably never reached --
            # confirm the intended order before changing it.
            os.kill(os.getpid(), signal.SIGKILL)
            os.kill(os.getppid(), signal.SIGKILL)

        def cur_protocol():
            # The factory always hands out the single shared protocol instance.
            return self.tcp_manager

        self.factory = MessageClientFactory()
        self.factory.protocol = cur_protocol
        if self.tls_path:
            from OpenSSL import SSL
            from twisted.internet import ssl
            client_key = os.path.join(self.tls_path, "client.key")
            client_crt = os.path.join(self.tls_path, "client.crt")
            client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt, SSL.TLSv1_2_METHOD)
            client_context_ = client_context_factory.getContext()
            client_context_.set_cipher_list(cipher_list)
            client_context_.set_options(SSL.OP_NO_RENEGOTIATION)
            endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory)
        else:
            endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port)
        d = endpoint.connect(self.factory)
        d.addCallback(conn_callback)
        d.addErrback(conn_err_callback)

        reactor_thread = Thread(target=self.run_reactor, daemon=True)
        reactor_thread.start()

    def send_after_queue_empty(self, data):
        """Keep queueing ``data`` every 2 seconds until an exit is signalled."""
        while not self._ready_to_exit():
            self.add_to_sending_queue(data)
            time.sleep(2)

    def check_client_alive(self):
        """Return True while the factory reports at least one connection.

        Returns False when ``start`` has not created the factory yet.
        """
        return self.factory is not None and self.factory.num_connections > 0

    def stop(self):
        """Close the connection via the protocol's timeout handler."""
        self.tcp_manager.connection_timeout()

    def send_stop_signal(self):
        """Send the STOP message, then wait for shutdown and the KILL ack."""
        self.send_after_queue_empty(self.ACK_STOP)
        while not self._ready_to_exit():
            if not self.check_client_alive():
                break
            time.sleep(1)
        while not self.tcp_manager.kill_process:
            time.sleep(1)

    def add_to_sending_queue(self, data: Union[bytes, TCPDataItem], rank: int = 0, step: int = 0):
        """Put ``data`` on the send queue, wrapping raw bytes in a ``TCPDataItem``.

        Does nothing once an exit has been signalled. A new sequence number is
        consumed only when ``data`` is not already a ``TCPDataItem``. If the
        queue stays full for ``QUEUE_PENDING_TIME`` seconds the item is
        dropped and an error is logged.
        """
        if self._ready_to_exit():
            return

        send_data = data
        if not isinstance(data, TCPDataItem):
            send_data = TCPDataItem(data=data,
                                    sequence_number=self.sequence_number,
                                    rank=rank,
                                    step=step)
            self.sequence_number += 1
        try:
            self.send_queue.put(send_data, block=True, timeout=self.QUEUE_PENDING_TIME)
        except Exception as e:
            logger.error(f"send_queue put send_data timeout, rank: {send_data.rank}, step: {send_data.step},"
                         f"sequence_number: {send_data.sequence_number}, {str(e)}")

    def _send_data(self, data: TCPDataItem):
        """Hand one item to the protocol for framing and transmission."""
        self.tcp_manager.send_wrapped_data(data.raw_data,
                                           sequence_number=data.sequence_number,
                                           rank=data.rank,
                                           step=data.step
                                           )

    def _sending_queue_data(self):
        """Worker loop: drain ``send_queue`` while the connection is up."""
        while True:
            if not self.tcp_manager.is_connected:
                # Wait for the connection without spinning a CPU core, and
                # stop waiting if an exit was signalled in the meantime.
                if self._ready_to_exit():
                    break
                time.sleep(0.1)
                continue

            while self.send_queue.qsize() > 0:
                if self._ready_to_exit():
                    break
                if len(self.resend_dict) < self.MAX_SENDING_QUEUE_SIZE:
                    data_obj = self.send_queue.get()
                    self._send_data(data_obj)
                    resend_key = self._resend_key(data_obj.sequence_number, data_obj.rank, data_obj.step)
                    if resend_key not in self.resend_dict:
                        # Send data for the first time
                        self.resend_dict[resend_key] = data_obj
                else:
                    # Too many unacknowledged items; wait for acks to arrive.
                    time.sleep(0.1)

            if self._ready_to_exit():
                logger.debug("Successfully close sending process.")
                break
            time.sleep(0.1)

    def _destroy_queue_data(self):
        """Worker loop: process acks and retire or re-queue tracked items."""
        while True:
            if self._ready_to_exit():
                break

            while len(self.resend_dict) > 0 and self.tcp_manager.ack_queue.qsize() > 0:
                ack_info, seq_number, rank, step = self.tcp_manager.ack_queue.get()
                obj_key = self._resend_key(seq_number, rank, step)
                current_item = self.resend_dict.get(obj_key)

                if current_item is None:
                    # Ack for an item that is not (or no longer) tracked.
                    continue

                if ack_info == self.ACK_SUCCESS:
                    self.resend_dict.pop(obj_key)
                elif ack_info == self.ACK_BUSY:
                    logger.debug("RECV BUSY ACK")
                    if current_item.busy_time > 5:
                        self._resend_data(current_item)
                    else:
                        current_item.busy_time += 1
                elif ack_info == self.ACK_ERROR:
                    logger.debug("RECV ERROR ACK")
                    self._resend_data(current_item)
                elif ack_info == self.ACK_STOP_CONFIRM:
                    logger.debug("RECV STOP ACK")
                    self.factory.num_connections -= 1

                    break

            time.sleep(0.1)

    def _resend_data(self, data: TCPDataItem):
        """Re-queue ``data`` or drop it once the retry budget is used up."""
        if data.retry_times < self.RESEND_RETRY_TIMES:
            data.retry_times += 1
            logger.debug(f"Resend data seq number: {data.sequence_number}")
            self.add_to_sending_queue(data)
        else:
            # resend_dict is keyed by "<seq>_<rank>_<step>", not by the bare
            # sequence number; popping the integer raised KeyError.
            self.resend_dict.pop(self._resend_key(data.sequence_number, data.rank, data.step), None)
            logger.debug(f"SKIP send sequence number {data.sequence_number} after retry {data.retry_times} times!")

    def _pending_data(self, data: TCPDataItem):
        """Sleep in proportion to the payload size, or drop a long-pending item."""
        if data.pending_time >= self.RESEND_PENDING_TIME:
            # Same string key as used everywhere else for resend_dict.
            self.resend_dict.pop(self._resend_key(data.sequence_number, data.rank, data.step), None)
            logger.debug(f"SKIP send sequence number {data.sequence_number} after pending {data.pending_time} times!")
            return

        # wait one second per 50 MiB of payload, and at least one second
        pending_time = max(1, len(data.raw_data) // (2 ** 20 * 50))
        data.pending_time += pending_time
        time.sleep(pending_time)

    def _ready_to_exit(self):
        """Return True once either the client or the protocol signalled exit."""
        return self.signal_exit or self.tcp_manager.signal_exit
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
class ClientProtocol(protocol.Protocol):
    """Twisted protocol that frames outgoing payloads and parses server acks.

    Every ack from the server is a fixed 29-byte frame: a 5-byte tag followed
    by three unsigned 64-bit big-endian integers (sequence number, rank,
    step). Parsed acks are published on ``ack_queue``.
    """

    TIMEOUT = 60 * 10
    ACK_FRAME_FORMAT = "!5sQQQ"  # tag + sequence number + rank + step
    ACK_FRAME_SIZE = struct.calcsize(ACK_FRAME_FORMAT)  # 5 + 8 * 3 = 29

    def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False):
        """Initialise buffers, the ack queue and the file sender.

        Args:
            ack_queue_size: Maximum number of parsed acks held in ``ack_queue``.
            chunk_size: Assigned to the ``FileSender``'s ``CHUNK_SIZE``.
            check_sum: When true, an MD5 hex digest of the payload is included
                in every outgoing frame.
        """
        self.buffer = io.BytesIO()
        self.is_connected = False
        self.check_sum = check_sum
        self.tell = 0
        self.ack_queue = Queue(maxsize=ack_queue_size)
        self.file_sender = FileSender()
        self.file_sender.CHUNK_SIZE = chunk_size
        self.signal_exit = False
        self.defer = None
        self.kill_process = False
        # Created in connectionMade; defined here so the attribute always exists.
        self.timeout_call = None

    def dataReceived(self, data):
        """Append ``data`` to the receive buffer and parse every complete ack.

        Only whole frames are consumed; a trailing partial frame stays in the
        buffer for the next call. An ack that arrives while ``ack_queue`` is
        full is discarded (as before), but parsing now continues safely
        because the loop checks the unread remainder rather than the total
        buffer length.
        """
        if self.timeout_call is not None and self.timeout_call.active():
            self.timeout_call.reset(self.TIMEOUT)

        self.buffer.seek(0, 2)
        self.buffer.write(data)
        received = self.buffer.getvalue()
        offset = self.tell
        while len(received) - offset >= self.ACK_FRAME_SIZE:
            ack, seq_number, rank, step = struct.unpack_from(self.ACK_FRAME_FORMAT, received, offset)
            offset += self.ACK_FRAME_SIZE
            if ack == b"KILL_":
                self.kill_process = True
                logger.debug(f"接收到KILL信号, PID {os.getpid()}")
            if ack == b"OVER_":
                self.factory.num_connections -= 1
            if not self.ack_queue.full():
                self.ack_queue.put((ack, seq_number, rank, step))
            else:
                logger.warning(f"ack_queue is full, drop ack {ack} of sequence number {seq_number}")
                # Give the consumer a moment to drain before the next frame.
                time.sleep(0.1)

        # Keep only the unparsed remainder so the buffer cannot grow forever.
        self.buffer = io.BytesIO(received[offset:])
        self.tell = 0

    def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0):
        """Frame ``data`` and send it once the previous transfer has finished.

        Frame layout: 8-byte length, 8-byte sequence number, 8-byte rank,
        8-byte step (all big-endian), the MD5 hex digest when ``check_sum`` is
        enabled (otherwise nothing), then the payload. Polls every 10 ms until
        the previous transfer's deferred has fired.

        NOTE(review): the visible caller runs this from a worker thread while
        the reactor runs in another thread -- confirm whether the transfer
        should be scheduled with ``reactor.callFromThread``.
        """
        length = len(data)
        md5_hash = hashlib.md5(data).hexdigest() if self.check_sum else ""
        while True:
            if self.defer is None or self.defer.called:
                self.defer = self.send_large_data(
                    length.to_bytes(8, byteorder='big') +
                    sequence_number.to_bytes(8, byteorder='big') +
                    rank.to_bytes(8, byteorder='big') +
                    step.to_bytes(8, byteorder='big') +
                    md5_hash.encode() +
                    data)
                break
            time.sleep(0.01)

    def send_large_data(self, data):
        """Stream ``data`` to the transport in chunks; return the deferred."""
        d = self.file_sender.beginFileTransfer(io.BytesIO(data), self.transport)
        return d

    def connection_timeout(self):
        """Drop the connection unless the factory already reports none."""
        if self.factory.num_connections <= 0:
            return

        self.factory.num_connections -= 1
        logger.debug(f"超时退出{self.transport.addr}, PID {os.getpid()}")
        self.transport.loseConnection()

    def connectionMade(self):
        """Arm the inactivity timer and register the new connection."""
        self.timeout_call = reactor.callLater(self.TIMEOUT, self.connection_timeout)
        self.is_connected = True
        self.factory.num_connections += 1
        logger.info("successfully connect server")

    def connectionLost(self, reason):
        """Signal exit and unregister the connection."""
        self.signal_exit = True
        self.factory.num_connections -= 1
        logger.info(f"Lost connection with server, reason is : {reason}")
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
class MessageClientFactory(protocol.ClientFactory):
    """Client factory holding the ``num_connections`` counter.

    Both connection-failure hooks log the reason and stop the reactor.
    """

    def __init__(self):
        # Connection counter, starts at zero.
        self.num_connections = 0

    def clientConnectionFailed(self, connector, reason):
        """Log why the connection attempt failed, then stop the reactor."""
        detail = reason.getErrorMessage()
        logger.info(f"Fail to connection with server: {detail}")
        reactor.stop()

    def clientConnectionLost(self, connector, reason):
        """Log why the connection was lost, then stop the reactor."""
        detail = reason.getErrorMessage()
        logger.info(f"Client lost connection with server: {detail}")
        reactor.stop()
|