mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
import io
|
|
2
|
+
import os.path
|
|
3
|
+
import time
|
|
4
|
+
import re
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from multiprocessing import Queue
|
|
7
|
+
from typing import Optional, Union, Dict, Any
|
|
8
|
+
from collections import namedtuple
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
|
|
14
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
|
|
15
|
+
from msprobe.pytorch.common.utils import logger
|
|
16
|
+
from msprobe.pytorch.common.utils import save_pt
|
|
17
|
+
from msprobe.core.common.utils import remove_path
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Record of one API invocation exchanged through the transport layer.
# Every field has a default, so ApiData() is valid on its own.
ApiData = namedtuple(
    'ApiData',
    'name args kwargs result step rank',
    defaults=('unknown', None, None, None, 0, 0),
)
# Payload types accepted by ATTL.send / ATTL.upload and produced by ATTL.recv.
BufferType = Union[ApiData, Dict[str, Any], str]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class ATTLConfig:
    """Session settings consumed by ``ATTL``."""
    # True -> ATTL starts a TCPServer; False -> ATTL starts a TCPClient (when need_dump).
    is_benchmark_device: bool
    # IPv4 address; validated by ATTL.check_attl_config unless nfs_path is set.
    connect_ip: str
    # TCP port; must be in 1..65535 unless nfs_path is set.
    connect_port: int
    # storage_config
    # Shared directory; when set, ATTL exchanges files there instead of opening sockets.
    nfs_path: Optional[str] = None
    # Directory holding the TLS key/certificate files; None means plain TCP.
    tls_path: Optional[str] = None
    # Forwarded to TCPServer / TCPClient as their check_sum argument.
    check_sum: bool = True
    # Maximum size of ATTL.data_queue.
    queue_size: int = 50
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ATTL:
    """API tensor transport layer between the device under test and the benchmark device.

    The transport is chosen in ``__init__``:
      * ``session_config.nfs_path`` set -> files in a shared directory (``upload`` / ``download``);
      * ``is_benchmark_device``         -> a ``TCPServer`` handed ``data_queue`` (read by ``recv``);
      * otherwise, if ``need_dump``     -> a ``TCPClient`` used by ``send``.
    """

    def __init__(self, session_id: str, session_config: ATTLConfig, need_dump=True) -> None:
        self.session_id = session_id
        self.session_config = session_config
        self.logger = logger
        self.socket_manager = None
        # Receive queue shared with TCPServer; bounded by the configured queue size
        # (previously hard-coded to 50 regardless of session_config.queue_size).
        self.data_queue = Queue(maxsize=self.session_config.queue_size)
        self.dequeue_list = []
        # Set once a b"KILL_" message is received; recv then reports KILL_ when the queue drains.
        self.message_end = False
        self.kill_progress = False
        self.check_attl_config()
        if self.session_config.nfs_path:
            # Only set in NFS mode; download() depends on it.
            self.nfs_path = Path(self.session_config.nfs_path)
        elif self.session_config.is_benchmark_device:
            self.socket_manager = TCPServer(self.session_config.connect_port,
                                            self.data_queue,
                                            self.session_config.check_sum,
                                            self.session_config.tls_path)
            self.socket_manager.start()
        elif need_dump:
            self.socket_manager = TCPClient(self.session_config.connect_ip,
                                            self.session_config.connect_port,
                                            self.session_config.check_sum,
                                            self.session_config.tls_path)
            self.socket_manager.start()

    def check_attl_config(self):
        """Validate ``session_config``, raising ``Exception`` on an invalid value.

        With ``nfs_path`` set, only that path's existence is checked. Otherwise,
        ``connect_ip`` must be a dotted IPv4 address and ``connect_port`` must be
        in 1..65535.
        """
        if self.session_config.nfs_path:
            if os.path.exists(self.session_config.nfs_path):
                return
            raise Exception(f"nfs path {self.session_config.nfs_path} doesn't exists.")
        # Raw string: "\d" and "\." are invalid escape sequences in a plain literal.
        ipv4_pattern = r"([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
        if not re.match(ipv4_pattern, self.session_config.connect_ip):
            raise Exception(f"host {self.session_config.connect_ip} is invalid.")
        if not (0 < self.session_config.connect_port <= 65535):
            raise Exception(f"port {self.session_config.connect_port} is invalid.")

    def stop_serve(self):
        """Stop the socket manager if this instance runs the TCP server."""
        if isinstance(self.socket_manager, TCPServer):
            self.socket_manager.stop()

    def send(self, buffer: BufferType) -> None:
        """Serialize ``buffer`` with ``torch.save`` and queue it on the TCP client.

        An ``ApiData`` has its tensors moved to CPU first, and any ``device`` kwarg
        is dropped. A buffer that cannot be serialized is logged and skipped.
        """
        if isinstance(buffer, ApiData):
            buffer = move2target_device(buffer, torch.device('cpu'))

        # kwargs may be None (the ApiData default) or absent (non-ApiData buffers).
        kwargs = getattr(buffer, "kwargs", None)
        if isinstance(kwargs, dict) and 'device' in kwargs:
            kwargs.pop('device')
        rank = buffer.rank if hasattr(buffer, "rank") and buffer.rank is not None else 0
        step = buffer.step if hasattr(buffer, "step") else 0
        io_buff = io.BytesIO()
        try:
            torch.save(buffer, io_buff)
        except Exception as e:
            name = getattr(buffer, "name", buffer)
            self.logger.info(f"{name} can not be saved, skip: {e}")
            return
        data = io_buff.getvalue()
        self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)

    def recv(self, timeout_ms=0) -> Optional[BufferType]:
        """Take the next item from ``data_queue`` and decode it.

        With ``timeout_ms > 0``, sleep that long once, then give up (return None)
        if the queue is still empty. With ``timeout_ms == 0``, poll every 0.1 s until
        data arrives or a previously received KILL_ message has drained the queue.

        Returns "STOP_" for b"STOP_" and b"KILL_", "KILL_" once the kill is confirmed,
        None on timeout, on undecodable data, or when the decoded object is bytes,
        and otherwise the object produced by ``torch.load``.
        """
        buffer = None
        while buffer is None:
            if timeout_ms > 0:
                time.sleep(timeout_ms / 1000.0)
            if not self.data_queue.empty():
                buffer = self.data_queue.get()
                break
            if timeout_ms > 0:  # timeout is the only case we give up and return None
                break
            if self.message_end and self.data_queue.empty():
                buffer = b"KILL_CONFIRM"
                self.kill_progress = True
                break
            time.sleep(0.1)  # waiting before next attempt

        if buffer is None:
            # this is a result of a timeout
            self.logger.info("RECEIVE API DATA TIMED OUT")
            return None
        if buffer == b"STOP_":
            return "STOP_"
        if buffer == b"KILL_":
            self.message_end = True
            return "STOP_"
        if buffer == b"KILL_CONFIRM":
            self.kill_progress = True
            return "KILL_"

        raw = io.BytesIO(buffer)
        try:
            # NOTE(review): torch.load unpickles the payload; only accept data from trusted peers.
            loaded = torch.load(raw, map_location="cpu")
        except Exception as e:
            self.logger.warning("there is something error. please check it. %s", e)
            # Return None rather than leaking the unread BytesIO stream to the caller.
            return None
        if isinstance(loaded, bytes):
            return None
        return loaded

    def upload(self, buffer: BufferType):
        """Save ``buffer`` into ``nfs_path`` via ``save_pt``.

        An ``ApiData`` is moved to CPU and written as ``<name>.pt``. Any other
        buffer (expected to be a str) is written as ``<buffer>_<unix time>``.
        Save errors are logged, not raised.
        """
        if isinstance(buffer, ApiData):
            buffer = move2target_device(buffer, torch.device('cpu'))
            file_path = os.path.join(self.session_config.nfs_path, buffer.name + ".pt")
        else:
            file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")

        try:
            save_pt(buffer, file_path)
        except Exception as e:
            self.logger.warning("there is something error in save_pt. please check it. %s", e)

    def download(self):
        """Load and delete one file from ``nfs_path``.

        Files are looked up in the order "start*", then "*.pt", then "end*"; the
        first match is used. Returns None when no file is found. If loading fails,
        the error is logged, the file is still removed, and None is returned.
        Requires NFS mode (``self.nfs_path`` is only set then).
        """
        for file_type in ("start*", "*.pt", "end*"):
            cur_file = next(self.nfs_path.glob(file_type), None)
            if cur_file is not None:
                break

        if cur_file is None:
            return None
        buffer = None
        try:
            # NOTE(review): torch.load unpickles file contents; the NFS directory must be trusted.
            buffer = torch.load(cur_file)
        except Exception as e:
            self.logger.warning("there is something error. please check it. %s", e)
        remove_path(cur_file)
        return buffer
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def move2device_exec(obj, device):
    """Recursively detach tensors inside ``obj`` and move them to ``device``.

    Args:
        obj: A tensor, a (nested) list/tuple/dict of them, a ``torch.device``,
            or any other object (returned unchanged).
        device: Target device, as a string (e.g. "cpu") or a ``torch.device``.

    Returns:
        A structure of the same container types (lists stay lists, tuples stay
        tuples, dicts are rebuilt) with every tensor detached and placed on
        ``device``. ``torch.device`` values are replaced by ``torch.device(device)``.
    """
    if isinstance(obj, (tuple, list)):
        data_list = [move2device_exec(val, device) for val in obj]
        return data_list if isinstance(obj, list) else tuple(data_list)
    if isinstance(obj, dict):
        return {key: move2device_exec(val, device) for key, val in obj.items()}
    if isinstance(obj, torch.Tensor):
        obj = obj.detach()
        # Normalise so str and torch.device targets compare correctly. A target without
        # an index matches any index of the same device type, as before.
        target = torch.device(device)
        needs_move = obj.device.type != target.type or (
            target.index is not None and obj.device.index != target.index)
        if needs_move:
            obj = obj.to(target)
        return obj
    if "return_types" in str(type(obj)):
        # NOTE(review): torch.return_types are tuple subclasses and are normally
        # handled by the tuple branch above; kept as a fallback.
        return move2device_exec(tuple(obj), device)
    if isinstance(obj, torch._C.device):
        return torch.device(device)
    return obj
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def move2target_device(buffer: ApiData, target_device):
    """Return a copy of ``buffer`` with its tensors moved to ``target_device``.

    ``args`` and ``kwargs`` are always moved. ``result`` is moved only when the
    target is the CPU; otherwise the original result object is kept. ``args`` is
    always returned as a tuple, and ``None`` args become ``()``.
    """
    # handle args
    new_args = move2device_exec(buffer.args, target_device)
    # tuple(None) would raise TypeError for the ApiData default args=None.
    new_args = tuple(new_args) if new_args is not None else ()

    # handle kwargs
    new_kwargs = move2device_exec(buffer.kwargs, target_device)

    # handle result: only needed when going to CPU, so skip the copy otherwise.
    to_cpu = target_device == torch.device('cpu') or target_device == "cpu"
    new_results = move2device_exec(buffer.result, target_device) if to_cpu else buffer.result

    return ApiData(buffer.name, new_args, new_kwargs, new_results, buffer.step, buffer.rank)
|
|
@@ -0,0 +1,324 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import io
|
|
3
|
+
import struct
|
|
4
|
+
import time
|
|
5
|
+
import os
|
|
6
|
+
import signal
|
|
7
|
+
import sys
|
|
8
|
+
from queue import Queue
|
|
9
|
+
from threading import Thread
|
|
10
|
+
from typing import Union
|
|
11
|
+
|
|
12
|
+
from OpenSSL import SSL
|
|
13
|
+
from twisted.internet import ssl, reactor, protocol, endpoints
|
|
14
|
+
from twisted.protocols.basic import FileSender
|
|
15
|
+
|
|
16
|
+
from msprobe.pytorch.common.utils import logger
|
|
17
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.ssl_config import cipher_list
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TCPDataItem:
    """A payload waiting to be sent, plus the counters used by the resend logic."""

    def __init__(self, data,
                 sequence_number: int,
                 rank: int = 0,
                 step: int = 0):
        # Identity of the payload: raw bytes and the (sequence, rank, step) triple.
        self.raw_data = data
        self.sequence_number = sequence_number
        self.rank = rank
        self.step = step
        # Bookkeeping counters, all starting at zero.
        self.retry_times = self.pending_time = self.busy_time = 0
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class TCPClient:
|
|
35
|
+
MAX_SENDING_QUEUE_SIZE = 20
|
|
36
|
+
ACK_SUCCESS = b"OK___"
|
|
37
|
+
ACK_ERROR = b"ERROR"
|
|
38
|
+
ACK_BUSY = b"BUSY_"
|
|
39
|
+
ACK_STOP = b"STOP_"
|
|
40
|
+
ACK_STOP_CONFIRM = b"OVER_"
|
|
41
|
+
ACK_KILL_PROCESS = b"KILL_"
|
|
42
|
+
|
|
43
|
+
QUEUE_PENDING_TIME = 600 # 队列10分钟都处于阻塞状态,则终止sending进程
|
|
44
|
+
RESEND_RETRY_TIMES = 2 # 最大重传数
|
|
45
|
+
RESEND_TIMER_TIME = 5 # 接收ACK超时定时器
|
|
46
|
+
RESEND_PENDING_TIME = 60 # 连续pending时间超过1分钟则放弃该数据
|
|
47
|
+
|
|
48
|
+
def __init__(self, host="localhost", port=8000, check_sum=False, tls_path=None):
|
|
49
|
+
self.send_queue = Queue(self.MAX_SENDING_QUEUE_SIZE)
|
|
50
|
+
self.resend_dict = dict()
|
|
51
|
+
self.host = host
|
|
52
|
+
self.port = port
|
|
53
|
+
self.tls_path = tls_path
|
|
54
|
+
self.factory = None
|
|
55
|
+
self.sequence_number = 0
|
|
56
|
+
self.signal_exit = False
|
|
57
|
+
self.tcp_manager = ClientProtocol(ack_queue_size=100,
|
|
58
|
+
chunk_size=655360,
|
|
59
|
+
check_sum=check_sum)
|
|
60
|
+
self.send_thread = Thread(target=self._sending_queue_data)
|
|
61
|
+
self.send_thread.setDaemon(True)
|
|
62
|
+
self.send_thread.start()
|
|
63
|
+
self.destroy_thread = Thread(target=self._destroy_queue_data)
|
|
64
|
+
self.destroy_thread.setDaemon(True)
|
|
65
|
+
self.destroy_thread.start()
|
|
66
|
+
|
|
67
|
+
@staticmethod
|
|
68
|
+
def run_reactor():
|
|
69
|
+
reactor.run(installSignalHandlers=False)
|
|
70
|
+
|
|
71
|
+
def start(self):
|
|
72
|
+
def conn_callback(cur_protocol):
|
|
73
|
+
if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host:
|
|
74
|
+
logger.debug(f"Process: {os.getpid()} connects to server successfully.")
|
|
75
|
+
else:
|
|
76
|
+
logger.warning(f"Process: {os.getpid()} fails to connect to server. ")
|
|
77
|
+
raise ConnectionError(f"Failed to connect to {self.host}.")
|
|
78
|
+
|
|
79
|
+
def conn_err_callback(failure):
|
|
80
|
+
self.signal_exit = True
|
|
81
|
+
time.sleep(1)
|
|
82
|
+
reactor.stop()
|
|
83
|
+
logger.error(f"Failed to connected {self.host} {self.port}. Reason is {failure.getErrorMessage()}")
|
|
84
|
+
os.kill(os.getpid(), signal.SIGKILL)
|
|
85
|
+
os.kill(os.getppid(), signal.SIGKILL)
|
|
86
|
+
|
|
87
|
+
def cur_protocol():
|
|
88
|
+
return self.tcp_manager
|
|
89
|
+
|
|
90
|
+
self.factory = MessageClientFactory()
|
|
91
|
+
self.factory.protocol = cur_protocol
|
|
92
|
+
if self.tls_path:
|
|
93
|
+
client_key = os.path.join(self.tls_path, "client.key")
|
|
94
|
+
client_crt = os.path.join(self.tls_path, "client.crt")
|
|
95
|
+
client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt, SSL.TLSv1_2_METHOD)
|
|
96
|
+
client_context_ = client_context_factory.getContext()
|
|
97
|
+
client_context_.set_cipher_list(cipher_list)
|
|
98
|
+
client_context_.set_options(SSL.OP_NO_RENEGOTIATION)
|
|
99
|
+
endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory)
|
|
100
|
+
else:
|
|
101
|
+
endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port)
|
|
102
|
+
d = endpoint.connect(self.factory)
|
|
103
|
+
d.addCallback(conn_callback)
|
|
104
|
+
d.addErrback(conn_err_callback)
|
|
105
|
+
|
|
106
|
+
reactor_thread = Thread(target=self.run_reactor, daemon=True)
|
|
107
|
+
reactor_thread.start()
|
|
108
|
+
|
|
109
|
+
def send_after_queue_empty(self, data):
|
|
110
|
+
while not self._ready_to_exit():
|
|
111
|
+
self.add_to_sending_queue(data)
|
|
112
|
+
time.sleep(2)
|
|
113
|
+
|
|
114
|
+
def check_client_alive(self):
|
|
115
|
+
return self.factory.num_connections > 0
|
|
116
|
+
|
|
117
|
+
def stop(self):
|
|
118
|
+
self.tcp_manager.connection_timeout()
|
|
119
|
+
|
|
120
|
+
def send_stop_signal(self):
|
|
121
|
+
self.send_after_queue_empty(self.ACK_STOP)
|
|
122
|
+
while not self._ready_to_exit():
|
|
123
|
+
if not self.check_client_alive():
|
|
124
|
+
break
|
|
125
|
+
time.sleep(1)
|
|
126
|
+
while not self.tcp_manager.kill_process:
|
|
127
|
+
time.sleep(1)
|
|
128
|
+
|
|
129
|
+
def add_to_sending_queue(self, data: Union[bytes, TCPDataItem], rank: int = 0, step: int = 0):
|
|
130
|
+
if self._ready_to_exit():
|
|
131
|
+
return
|
|
132
|
+
|
|
133
|
+
send_data = data
|
|
134
|
+
if not isinstance(data, TCPDataItem):
|
|
135
|
+
send_data = TCPDataItem(data=data,
|
|
136
|
+
sequence_number=self.sequence_number,
|
|
137
|
+
rank=rank,
|
|
138
|
+
step=step)
|
|
139
|
+
self.sequence_number += 1
|
|
140
|
+
try:
|
|
141
|
+
self.send_queue.put(send_data, block=True, timeout=self.QUEUE_PENDING_TIME)
|
|
142
|
+
except Exception as e:
|
|
143
|
+
logger.error(f"send_queue put send_data timeout, rank: {send_data.rank}, step: {send_data.step},"
|
|
144
|
+
f"sequence_number: {send_data.sequence_number}, {str(e)}")
|
|
145
|
+
|
|
146
|
+
def _send_data(self, data: TCPDataItem):
|
|
147
|
+
self.tcp_manager.send_wrapped_data(data.raw_data,
|
|
148
|
+
sequence_number=data.sequence_number,
|
|
149
|
+
rank=data.rank,
|
|
150
|
+
step=data.step
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
def _sending_queue_data(self):
|
|
154
|
+
while True:
|
|
155
|
+
if not self.tcp_manager.is_connected:
|
|
156
|
+
continue
|
|
157
|
+
|
|
158
|
+
while self.send_queue.qsize() > 0:
|
|
159
|
+
if self._ready_to_exit():
|
|
160
|
+
break
|
|
161
|
+
if len(self.resend_dict) < self.MAX_SENDING_QUEUE_SIZE:
|
|
162
|
+
data_obj = self.send_queue.get()
|
|
163
|
+
self._send_data(data_obj)
|
|
164
|
+
resend_key = str(data_obj.sequence_number) + "_" + str(data_obj.rank) + "_" + str(data_obj.step)
|
|
165
|
+
if resend_key not in self.resend_dict.keys():
|
|
166
|
+
# Send data for the first time
|
|
167
|
+
self.resend_dict[resend_key] = data_obj
|
|
168
|
+
else:
|
|
169
|
+
time.sleep(0.1)
|
|
170
|
+
|
|
171
|
+
if self._ready_to_exit():
|
|
172
|
+
logger.debug("Successfully close sending process.")
|
|
173
|
+
break
|
|
174
|
+
time.sleep(0.1)
|
|
175
|
+
|
|
176
|
+
def _destroy_queue_data(self):
|
|
177
|
+
while True:
|
|
178
|
+
if self._ready_to_exit():
|
|
179
|
+
break
|
|
180
|
+
|
|
181
|
+
while len(self.resend_dict) > 0 and self.tcp_manager.ack_queue.qsize() > 0:
|
|
182
|
+
ack_info, seq_number, rank, step = self.tcp_manager.ack_queue.get()
|
|
183
|
+
obj_key = str(seq_number) + "_" + str(rank) + "_" + str(step)
|
|
184
|
+
current_item = self.resend_dict.get(obj_key)
|
|
185
|
+
|
|
186
|
+
if current_item is None:
|
|
187
|
+
continue
|
|
188
|
+
|
|
189
|
+
if ack_info == self.ACK_SUCCESS:
|
|
190
|
+
self.resend_dict.pop(obj_key)
|
|
191
|
+
elif ack_info == self.ACK_BUSY:
|
|
192
|
+
logger.debug("RECV BUSY ACK")
|
|
193
|
+
if current_item.busy_time > 5:
|
|
194
|
+
self._resend_data(current_item)
|
|
195
|
+
else:
|
|
196
|
+
current_item.busy_time += 1
|
|
197
|
+
elif ack_info == self.ACK_ERROR:
|
|
198
|
+
logger.debug("RECV ERROR ACK")
|
|
199
|
+
self._resend_data(current_item)
|
|
200
|
+
elif ack_info == self.ACK_STOP_CONFIRM:
|
|
201
|
+
logger.debug("RECV STOP ACK")
|
|
202
|
+
self.factory.num_connections -= 1
|
|
203
|
+
|
|
204
|
+
break
|
|
205
|
+
|
|
206
|
+
time.sleep(0.1)
|
|
207
|
+
|
|
208
|
+
def _resend_data(self, data: TCPDataItem):
    """Re-queue ``data`` for sending, or drop it once its retry budget is spent.

    While ``data.retry_times`` is below ``RESEND_RETRY_TIMES``, the counter is
    incremented and the item is put back through ``add_to_sending_queue``.
    Otherwise the item is removed from ``resend_dict`` and a debug message
    records the skip.
    """
    if data.retry_times < self.RESEND_RETRY_TIMES:
        data.retry_times += 1
        logger.debug(f"Resend data seq number: {data.sequence_number}")
        self.add_to_sending_queue(data)
    else:
        # resend_dict is keyed by "<sequence_number>_<rank>_<step>" (the key the
        # sending loop and the ACK loop build), not by the bare sequence number.
        # Popping with the bare number always raised KeyError. The default
        # tolerates an entry that has already been removed.
        resend_key = str(data.sequence_number) + "_" + str(data.rank) + "_" + str(data.step)
        self.resend_dict.pop(resend_key, None)
        logger.debug(f"SKIP send sequence number {data.sequence_number} after retry {data.retry_times} times!")
|
|
216
|
+
|
|
217
|
+
def _pending_data(self, data: TCPDataItem):
    """Back off before ``data`` is sent again, or drop it after waiting too long.

    If the accumulated ``data.pending_time`` has reached ``RESEND_PENDING_TIME``,
    the item is removed from ``resend_dict`` and the method returns immediately.
    Otherwise it sleeps for a payload-size-dependent number of seconds (at
    least 1) and adds that duration to ``data.pending_time``.
    """
    if data.pending_time >= self.RESEND_PENDING_TIME:
        # resend_dict is keyed by "<sequence_number>_<rank>_<step>" (the key the
        # sending loop and the ACK loop build), not by the bare sequence number.
        # Popping with the bare number always raised KeyError. The default
        # tolerates an entry that has already been removed.
        resend_key = str(data.sequence_number) + "_" + str(data.rank) + "_" + str(data.step)
        self.resend_dict.pop(resend_key, None)
        logger.debug(f"SKIP send sequence number {data.sequence_number} after pending {data.pending_time} times!")
        return

    # Budget one second per 50 MiB (2 ** 20 * 50 bytes) of payload, minimum 1 s.
    pending_time = max(1, len(data.raw_data) // (2 ** 20 * 50))
    data.pending_time += pending_time
    time.sleep(pending_time)
|
|
227
|
+
|
|
228
|
+
def _ready_to_exit(self):
    """Report whether this client or its TCP protocol has been asked to exit.

    Returns this object's own ``signal_exit`` when it is truthy. Otherwise
    returns ``self.tcp_manager.signal_exit``.
    """
    own_flag = self.signal_exit
    if own_flag:
        return own_flag
    return self.tcp_manager.signal_exit
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class ClientProtocol(protocol.Protocol):
    """Client-side Twisted protocol that sends framed payloads and collects acks.

    Outgoing payloads are framed by ``send_wrapped_data`` and streamed through a
    ``FileSender``. Incoming bytes are parsed by ``dataReceived`` into fixed-size
    ack frames ``(ack, seq_number, rank, step)``. The frames are put on
    ``ack_queue`` for a consumer elsewhere. An inactivity timer (``TIMEOUT``
    seconds, reset whenever data arrives) drops the connection through
    ``connection_timeout``.
    """

    TIMEOUT = 60 * 10
    # Ack frame layout: a 5-byte status code followed by three big-endian
    # unsigned 64-bit fields (sequence number, rank, step): 5 + 8 * 3 bytes.
    ACK_FRAME_FORMAT = '!5sQQQ'
    ACK_FRAME_SIZE = 29

    def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False):
        """Create an unconnected protocol instance.

        Args:
            ack_queue_size: Capacity of ``ack_queue``.
            chunk_size: Chunk size used by the ``FileSender`` when streaming
                payloads.
            check_sum: When True, ``send_wrapped_data`` inserts the MD5 hex
                digest of each payload into its header.
        """
        self.buffer = io.BytesIO()  # received bytes that have not been discarded yet
        self.is_connected = False
        self.check_sum = check_sum
        self.tell = 0  # offset in ``buffer`` of the first unparsed byte
        self.ack_queue = Queue(maxsize=ack_queue_size)
        self.file_sender = FileSender()
        self.file_sender.CHUNK_SIZE = chunk_size  # per-instance override of the sender's chunk size
        self.signal_exit = False  # set once the connection is lost
        self.defer = None  # Deferred of the transfer currently in flight, if any
        self.kill_process = False  # set when the server sends a KILL_ ack

    def dataReceived(self, data):
        """Buffer incoming bytes and enqueue every complete ack frame.

        A ``KILL_`` ack sets ``kill_process``. An ``OVER_`` ack decrements the
        factory's ``num_connections``. Every parsed frame is put on ``ack_queue``
        as ``(ack, seq_number, rank, step)``. An incomplete trailing frame stays
        buffered until more bytes arrive.
        """
        # Any traffic from the server postpones the inactivity timeout.
        if self.timeout_call.active():
            self.timeout_call.reset(self.TIMEOUT)

        # Append the new bytes, then resume parsing where the previous call stopped.
        self.buffer.seek(0, 2)
        self.buffer.write(data)
        self.buffer.seek(self.tell)
        # Only the bytes past ``tell`` are still unparsed.
        while len(self.buffer.getvalue()) - self.tell >= self.ACK_FRAME_SIZE:
            frame = self.buffer.read(self.ACK_FRAME_SIZE)
            ack, seq_number, rank, step = struct.unpack(self.ACK_FRAME_FORMAT, frame)
            if ack == b"KILL_":
                self.kill_process = True
                logger.debug(f"接收到KILL信号, PID {os.getpid()}")
            if ack == b"OVER_":
                self.factory.num_connections -= 1
            self.tell += self.ACK_FRAME_SIZE

            # Wait for room instead of looping back into the parser. Retrying the
            # parse would read the *next* frame's bytes while this ack was still
            # unqueued, losing the ack or failing on a short read.
            # NOTE(review): this sleeps on the thread running dataReceived. It
            # relies on another thread draining ack_queue.
            while self.ack_queue.full():
                time.sleep(0.1)
            self.ack_queue.put((ack, seq_number, rank, step))

            # Discard the consumed bytes so the buffer does not grow without bound.
            self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:])
            self.tell = 0

    def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0):
        """Frame ``data`` and stream it once no earlier transfer is in flight.

        Frame layout:

        * payload length, sequence number, rank and step, each as an 8-byte
          big-endian integer;
        * the payload's MD5 hex digest when ``check_sum`` is set (nothing
          otherwise);
        * the payload bytes.

        NOTE(review): this polls ``self.defer`` with ``time.sleep``. It presumably
        must be called from a thread other than the one that fires the transfer's
        Deferred. Confirm against callers.
        """
        length = len(data)
        md5_hash = hashlib.md5(data).hexdigest() if self.check_sum else ""
        while True:
            if self.defer is None or self.defer.called:
                self.defer = self.send_large_data(
                    length.to_bytes(8, byteorder='big') +
                    sequence_number.to_bytes(8, byteorder='big') +
                    rank.to_bytes(8, byteorder='big') +
                    step.to_bytes(8, byteorder='big') +
                    md5_hash.encode() +
                    data)
                break
            time.sleep(0.01)

    def send_large_data(self, data):
        """Stream ``data`` to the transport with ``FileSender``.

        Returns the Deferred produced by ``beginFileTransfer``.
        """
        d = self.file_sender.beginFileTransfer(io.BytesIO(data), self.transport)
        return d

    def connection_timeout(self):
        """Drop the connection after ``TIMEOUT`` seconds without incoming data.

        Does nothing when the factory already counts no open connections.
        """
        if self.factory.num_connections <= 0:
            return

        # NOTE(review): loseConnection() leads to connectionLost(), which
        # decrements num_connections again. Verify the double decrement is
        # intended.
        self.factory.num_connections -= 1
        logger.debug(f"超时退出{self.transport.addr}, PID {os.getpid()}")
        self.transport.loseConnection()

    def connectionMade(self):
        """Arm the inactivity timer and count the new connection."""
        self.timeout_call = reactor.callLater(self.TIMEOUT, self.connection_timeout)
        self.is_connected = True
        self.factory.num_connections += 1
        logger.info("successfully connect server")

    def connectionLost(self, reason):
        """Signal exit to the owner and un-count the connection."""
        self.signal_exit = True
        self.factory.num_connections -= 1
        logger.info(f"Lost connection with server, reason is : {reason}")
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
class MessageClientFactory(protocol.ClientFactory):
    """Client factory that tracks open connections and stops the reactor when the link ends."""

    def __init__(self):
        # Incremented and decremented by the protocol instances this factory serves.
        self.num_connections = 0

    def clientConnectionFailed(self, connector, reason):
        """Log a failed connection attempt and stop the reactor."""
        logger.info(f"Fail to connection with server: {reason.getErrorMessage()}")
        self._stop_reactor()

    def clientConnectionLost(self, connector, reason):
        """Log a lost connection and stop the reactor."""
        logger.info(f"Client lost connection with server: {reason.getErrorMessage()}")
        self._stop_reactor()

    @staticmethod
    def _stop_reactor():
        """Stop the reactor, tolerating a reactor that is already stopping.

        ``reactor.stop()`` raises ``ReactorNotRunning`` when it is called a
        second time, e.g. when a connection is lost because the reactor is
        already shutting down. That case is deliberately ignored.
        """
        from twisted.internet.error import ReactorNotRunning
        try:
            reactor.stop()
        except ReactorNotRunning:
            pass
|