mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,324 +1,370 @@
1
- import hashlib
2
- import io
3
- import struct
4
- import time
5
- import os
6
- import signal
7
- import sys
8
- from queue import Queue
9
- from threading import Thread
10
- from typing import Union
11
-
12
- from OpenSSL import SSL
13
- from twisted.internet import ssl, reactor, protocol, endpoints
14
- from twisted.protocols.basic import FileSender
15
-
16
- from msprobe.pytorch.common.utils import logger
17
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.ssl_config import cipher_list
18
-
19
-
20
- class TCPDataItem:
21
- def __init__(self, data,
22
- sequence_number: int,
23
- rank: int = 0,
24
- step: int = 0):
25
- self.raw_data = data
26
- self.sequence_number = sequence_number
27
- self.rank = rank
28
- self.step = step
29
- self.retry_times = 0
30
- self.pending_time = 0
31
- self.busy_time = 0
32
-
33
-
34
- class TCPClient:
35
- MAX_SENDING_QUEUE_SIZE = 20
36
- ACK_SUCCESS = b"OK___"
37
- ACK_ERROR = b"ERROR"
38
- ACK_BUSY = b"BUSY_"
39
- ACK_STOP = b"STOP_"
40
- ACK_STOP_CONFIRM = b"OVER_"
41
- ACK_KILL_PROCESS = b"KILL_"
42
-
43
- QUEUE_PENDING_TIME = 600 # 队列10分钟都处于阻塞状态,则终止sending进程
44
- RESEND_RETRY_TIMES = 2 # 最大重传数
45
- RESEND_TIMER_TIME = 5 # 接收ACK超时定时器
46
- RESEND_PENDING_TIME = 60 # 连续pending时间超过1分钟则放弃该数据
47
-
48
- def __init__(self, host="localhost", port=8000, check_sum=False, tls_path=None):
49
- self.send_queue = Queue(self.MAX_SENDING_QUEUE_SIZE)
50
- self.resend_dict = dict()
51
- self.host = host
52
- self.port = port
53
- self.tls_path = tls_path
54
- self.factory = None
55
- self.sequence_number = 0
56
- self.signal_exit = False
57
- self.tcp_manager = ClientProtocol(ack_queue_size=100,
58
- chunk_size=655360,
59
- check_sum=check_sum)
60
- self.send_thread = Thread(target=self._sending_queue_data)
61
- self.send_thread.setDaemon(True)
62
- self.send_thread.start()
63
- self.destroy_thread = Thread(target=self._destroy_queue_data)
64
- self.destroy_thread.setDaemon(True)
65
- self.destroy_thread.start()
66
-
67
- @staticmethod
68
- def run_reactor():
69
- reactor.run(installSignalHandlers=False)
70
-
71
- def start(self):
72
- def conn_callback(cur_protocol):
73
- if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host:
74
- logger.debug(f"Process: {os.getpid()} connects to server successfully.")
75
- else:
76
- logger.warning(f"Process: {os.getpid()} fails to connect to server. ")
77
- raise ConnectionError(f"Failed to connect to {self.host}.")
78
-
79
- def conn_err_callback(failure):
80
- self.signal_exit = True
81
- time.sleep(1)
82
- reactor.stop()
83
- logger.error(f"Failed to connected {self.host} {self.port}. Reason is {failure.getErrorMessage()}")
84
- os.kill(os.getpid(), signal.SIGKILL)
85
- os.kill(os.getppid(), signal.SIGKILL)
86
-
87
- def cur_protocol():
88
- return self.tcp_manager
89
-
90
- self.factory = MessageClientFactory()
91
- self.factory.protocol = cur_protocol
92
- if self.tls_path:
93
- client_key = os.path.join(self.tls_path, "client.key")
94
- client_crt = os.path.join(self.tls_path, "client.crt")
95
- client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt, SSL.TLSv1_2_METHOD)
96
- client_context_ = client_context_factory.getContext()
97
- client_context_.set_cipher_list(cipher_list)
98
- client_context_.set_options(SSL.OP_NO_RENEGOTIATION)
99
- endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory)
100
- else:
101
- endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port)
102
- d = endpoint.connect(self.factory)
103
- d.addCallback(conn_callback)
104
- d.addErrback(conn_err_callback)
105
-
106
- reactor_thread = Thread(target=self.run_reactor, daemon=True)
107
- reactor_thread.start()
108
-
109
- def send_after_queue_empty(self, data):
110
- while not self._ready_to_exit():
111
- self.add_to_sending_queue(data)
112
- time.sleep(2)
113
-
114
- def check_client_alive(self):
115
- return self.factory.num_connections > 0
116
-
117
- def stop(self):
118
- self.tcp_manager.connection_timeout()
119
-
120
- def send_stop_signal(self):
121
- self.send_after_queue_empty(self.ACK_STOP)
122
- while not self._ready_to_exit():
123
- if not self.check_client_alive():
124
- break
125
- time.sleep(1)
126
- while not self.tcp_manager.kill_process:
127
- time.sleep(1)
128
-
129
- def add_to_sending_queue(self, data: Union[bytes, TCPDataItem], rank: int = 0, step: int = 0):
130
- if self._ready_to_exit():
131
- return
132
-
133
- send_data = data
134
- if not isinstance(data, TCPDataItem):
135
- send_data = TCPDataItem(data=data,
136
- sequence_number=self.sequence_number,
137
- rank=rank,
138
- step=step)
139
- self.sequence_number += 1
140
- try:
141
- self.send_queue.put(send_data, block=True, timeout=self.QUEUE_PENDING_TIME)
142
- except Exception as e:
143
- logger.error(f"send_queue put send_data timeout, rank: {send_data.rank}, step: {send_data.step},"
144
- f"sequence_number: {send_data.sequence_number}, {str(e)}")
145
-
146
- def _send_data(self, data: TCPDataItem):
147
- self.tcp_manager.send_wrapped_data(data.raw_data,
148
- sequence_number=data.sequence_number,
149
- rank=data.rank,
150
- step=data.step
151
- )
152
-
153
- def _sending_queue_data(self):
154
- while True:
155
- if not self.tcp_manager.is_connected:
156
- continue
157
-
158
- while self.send_queue.qsize() > 0:
159
- if self._ready_to_exit():
160
- break
161
- if len(self.resend_dict) < self.MAX_SENDING_QUEUE_SIZE:
162
- data_obj = self.send_queue.get()
163
- self._send_data(data_obj)
164
- resend_key = str(data_obj.sequence_number) + "_" + str(data_obj.rank) + "_" + str(data_obj.step)
165
- if resend_key not in self.resend_dict.keys():
166
- # Send data for the first time
167
- self.resend_dict[resend_key] = data_obj
168
- else:
169
- time.sleep(0.1)
170
-
171
- if self._ready_to_exit():
172
- logger.debug("Successfully close sending process.")
173
- break
174
- time.sleep(0.1)
175
-
176
- def _destroy_queue_data(self):
177
- while True:
178
- if self._ready_to_exit():
179
- break
180
-
181
- while len(self.resend_dict) > 0 and self.tcp_manager.ack_queue.qsize() > 0:
182
- ack_info, seq_number, rank, step = self.tcp_manager.ack_queue.get()
183
- obj_key = str(seq_number) + "_" + str(rank) + "_" + str(step)
184
- current_item = self.resend_dict.get(obj_key)
185
-
186
- if current_item is None:
187
- continue
188
-
189
- if ack_info == self.ACK_SUCCESS:
190
- self.resend_dict.pop(obj_key)
191
- elif ack_info == self.ACK_BUSY:
192
- logger.debug("RECV BUSY ACK")
193
- if current_item.busy_time > 5:
194
- self._resend_data(current_item)
195
- else:
196
- current_item.busy_time += 1
197
- elif ack_info == self.ACK_ERROR:
198
- logger.debug("RECV ERROR ACK")
199
- self._resend_data(current_item)
200
- elif ack_info == self.ACK_STOP_CONFIRM:
201
- logger.debug("RECV STOP ACK")
202
- self.factory.num_connections -= 1
203
-
204
- break
205
-
206
- time.sleep(0.1)
207
-
208
- def _resend_data(self, data: TCPDataItem):
209
- if data.retry_times < self.RESEND_RETRY_TIMES:
210
- data.retry_times += 1
211
- logger.debug(f"Resend data seq number: {data.sequence_number}")
212
- self.add_to_sending_queue(data)
213
- else:
214
- self.resend_dict.pop(data.sequence_number)
215
- logger.debug(f"SKIP send sequence number {data.sequence_number} after retry {data.retry_times} times!")
216
-
217
- def _pending_data(self, data: TCPDataItem):
218
- if data.pending_time >= self.RESEND_PENDING_TIME:
219
- self.resend_dict.pop(data.sequence_number)
220
- logger.debug(f"SKIP send sequence number {data.sequence_number} after pending {data.pending_time} times!")
221
- return
222
-
223
- # wait time is 100MB per second
224
- pending_time = max(1, len(data.raw_data) // (2 ** 20 * 50))
225
- data.pending_time += pending_time
226
- time.sleep(pending_time)
227
-
228
- def _ready_to_exit(self):
229
- return self.signal_exit or self.tcp_manager.signal_exit
230
-
231
-
232
- class ClientProtocol(protocol.Protocol):
233
- TIMEOUT = 60 * 10
234
-
235
- def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False):
236
- self.buffer = io.BytesIO()
237
- self.is_connected = False
238
- self.check_sum = check_sum
239
- self.tell = 0
240
- self.ack_queue = Queue(maxsize=ack_queue_size)
241
- self.file_sender = FileSender()
242
- self.file_sender.CHUNK_SIZE = chunk_size
243
- self.signal_exit = False
244
- self.defer = None
245
- self.kill_process = False
246
-
247
- def dataReceived(self, data):
248
- if self.timeout_call.active():
249
- self.timeout_call.reset(self.TIMEOUT)
250
-
251
- self.buffer.seek(0, 2)
252
- self.buffer.write(data)
253
- self.buffer.seek(self.tell)
254
- while True:
255
- if len(self.buffer.getvalue()) >= 29: # 5 + 8 * 3
256
- ack = self.buffer.read(5)
257
- seq_number = struct.unpack('!Q', self.buffer.read(8))[0]
258
- rank = struct.unpack('!Q', self.buffer.read(8))[0]
259
- step = struct.unpack('!Q', self.buffer.read(8))[0]
260
- if ack == b"KILL_":
261
- self.kill_process = True
262
- logger.debug(f"接收到KILL信号, PID {os.getpid()}")
263
- if ack == b"OVER_":
264
- self.factory.num_connections -= 1
265
- self.tell += 29
266
- if not self.ack_queue.full():
267
- self.ack_queue.put((ack, seq_number, rank, step))
268
- self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:])
269
- self.tell = 0
270
- else:
271
- time.sleep(0.1)
272
- else:
273
- break
274
-
275
- def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0):
276
- length = len(data)
277
- md5_hash = hashlib.md5(data).hexdigest() if self.check_sum else ""
278
- while True:
279
- if self.defer is None or self.defer.called:
280
- self.defer = self.send_large_data(
281
- length.to_bytes(8, byteorder='big') +
282
- sequence_number.to_bytes(8, byteorder='big') +
283
- rank.to_bytes(8, byteorder='big') +
284
- step.to_bytes(8, byteorder='big') +
285
- md5_hash.encode() +
286
- data)
287
- break
288
- time.sleep(0.01)
289
-
290
- def send_large_data(self, data):
291
- d = self.file_sender.beginFileTransfer(io.BytesIO(data), self.transport)
292
- return d
293
-
294
- def connection_timeout(self):
295
- if self.factory.num_connections <= 0:
296
- return
297
-
298
- self.factory.num_connections -= 1
299
- logger.debug(f"超时退出{self.transport.addr}, PID {os.getpid()}")
300
- self.transport.loseConnection()
301
-
302
- def connectionMade(self):
303
- self.timeout_call = reactor.callLater(self.TIMEOUT, self.connection_timeout)
304
- self.is_connected = True
305
- self.factory.num_connections += 1
306
- logger.info("successfully connect server")
307
-
308
- def connectionLost(self, reason):
309
- self.signal_exit = True
310
- self.factory.num_connections -= 1
311
- logger.info(f"Lost connection with server, reason is : {reason}")
312
-
313
-
314
- class MessageClientFactory(protocol.ClientFactory):
315
- def __init__(self):
316
- self.num_connections = 0
317
-
318
- def clientConnectionFailed(self, connector, reason):
319
- logger.info(f"Fail to connection with server: {reason.getErrorMessage()}")
320
- reactor.stop()
321
-
322
- def clientConnectionLost(self, connector, reason):
323
- logger.info(f"Client lost connection with server: {reason.getErrorMessage()}")
324
- reactor.stop()
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import hashlib
17
+ import io
18
+ import struct
19
+ import time
20
+ import os
21
+ import signal
22
+ from queue import Queue
23
+ from threading import Thread
24
+ from typing import Union
25
+
26
+ from twisted.internet import reactor, protocol, endpoints
27
+ from twisted.protocols.basic import FileSender
28
+
29
+ from msprobe.pytorch.common.utils import logger
30
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import struct_unpack_mode as unpack_mode, \
31
+ str_to_bytes_order as bytes_order
32
+
33
+ MAX_SENDING_QUEUE_SIZE = 20
34
+
35
+
36
+ class TCPDataItem:
37
+ def __init__(self, data,
38
+ sequence_number: int,
39
+ rank: int = 0,
40
+ step: int = 0):
41
+ self.raw_data = data
42
+ self.sequence_number = sequence_number
43
+ self.rank = rank
44
+ self.step = step
45
+ self.retry_times = 0
46
+ self.pending_time = 0
47
+ self.busy_time = 0
48
+
49
+
50
class TCPClient:
    """Queue-backed sender that ships payloads to a server through a Twisted connection.

    Payloads are wrapped in ``TCPDataItem`` objects and placed on ``send_queue``.
    A daemon thread (``_sending_queue_data``) takes items off the queue, hands them to
    the shared ``ClientProtocol`` instance and remembers them in ``resend_dict`` until
    a second daemon thread (``_destroy_queue_data``) sees the matching acknowledgement.
    Entries in ``resend_dict`` are keyed by the string ``"<sequence>_<rank>_<step>"``.
    """

    ACK_SUCCESS = b"OK___"
    ACK_ERROR = b"ERROR"
    ACK_BUSY = b"BUSY_"
    ACK_STOP = b"STOP_"
    ACK_STOP_CONFIRM = b"OVER_"
    ACK_KILL_PROCESS = b"KILL_"

    QUEUE_PENDING_TIME = 60
    RESEND_RETRY_TIMES = 2  # maximum number of retransmissions
    RESEND_TIMER_TIME = 5  # timeout timer for receiving an ACK
    RESEND_PENDING_TIME = 60  # give up on an item once it has been pending for more than one minute

    def __init__(self, host="localhost", port=8000, check_sum=False, tls_path=None):
        """Create the queues and protocol object and start both worker threads.

        Args:
            host: server host passed to the Twisted endpoint in ``start``.
            port: server port passed to the Twisted endpoint in ``start``.
            check_sum: forwarded to ``ClientProtocol``; enables the MD5 field.
            tls_path: directory holding ``client.key`` / ``client.crt``; a falsy
                value selects a plain TCP endpoint.
        """
        self.send_queue = Queue(MAX_SENDING_QUEUE_SIZE)
        self.resend_dict = dict()
        self.host = host
        self.port = port
        self.tls_path = tls_path
        self.factory = None
        self.sequence_number = 0
        self.signal_exit = False
        self.tcp_manager = ClientProtocol(ack_queue_size=100,
                                          chunk_size=655360,
                                          check_sum=check_sum,
                                          tls=self.tls_path)
        # ``daemon=True`` replaces the deprecated ``Thread.setDaemon(True)`` and matches
        # how the reactor thread is created in ``start``.
        self.send_thread = Thread(target=self._sending_queue_data, daemon=True)
        self.send_thread.start()
        self.destroy_thread = Thread(target=self._destroy_queue_data, daemon=True)
        self.destroy_thread.start()

    @staticmethod
    def run_reactor():
        """Run the Twisted reactor without installing signal handlers."""
        reactor.run(installSignalHandlers=False)

    @staticmethod
    def _resend_key(sequence_number, rank, step):
        """Build the ``resend_dict`` key for a (sequence number, rank, step) triple."""
        return str(sequence_number) + "_" + str(rank) + "_" + str(step)

    def check_tls_path(self):
        """Return ``(client_key, client_crt)`` paths under ``tls_path``.

        Raises:
            FileNotFoundError: if either file does not exist. (This is a subclass of
                ``Exception``, which the previous implementation raised.)
        """
        client_key = os.path.join(self.tls_path, "client.key")
        client_crt = os.path.join(self.tls_path, "client.crt")
        if not os.path.exists(client_key):
            raise FileNotFoundError(f"client_key: {client_key} is not exists.")
        if not os.path.exists(client_crt):
            raise FileNotFoundError(f"client_crt: {client_crt} is not exists.")
        return client_key, client_crt

    def start(self):
        """Open the connection and run the reactor in a daemon thread.

        A TLS endpoint is used when ``tls_path`` is set, otherwise a plain TCP one.
        On a connection error ``signal_exit`` is set and the reactor is stopped.
        """
        def conn_callback(cur_protocol):
            # NOTE(review): the peer address is compared to ``self.host`` literally;
            # presumably this fails when ``host`` is a name such as "localhost" while
            # the peer reports an IP address -- confirm against the transport used.
            if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host:
                logger.debug(f"Process: {os.getpid()} connects to server successfully.")
            else:
                logger.warning(f"Process: {os.getpid()} fails to connect to server. ")
                # Raising here routes into conn_err_callback, which is chained below.
                raise ConnectionError(f"Failed to connect to {self.host}.")

        def conn_err_callback(failure):
            self.signal_exit = True
            time.sleep(1)
            reactor.stop()
            logger.error(f"Failed to connected {self.host} {self.port}. Reason is {failure.getErrorMessage()}")

        def cur_protocol():
            # Every connection shares the single protocol instance built in __init__.
            return self.tcp_manager

        self.factory = MessageClientFactory()
        self.factory.protocol = cur_protocol
        if self.tls_path:
            from twisted.internet import ssl
            client_key, client_crt = self.check_tls_path()
            client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt)
            endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory)
        else:
            endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port)
        d = endpoint.connect(self.factory)
        d.addCallback(conn_callback)
        d.addErrback(conn_err_callback)

        reactor_thread = Thread(target=self.run_reactor, daemon=True)
        reactor_thread.start()

    def send_after_queue_empty(self, data):
        """Keep enqueuing ``data`` every two seconds until the client is ready to exit.

        In TLS mode the payload is enqueued ``MAX_SENDING_QUEUE_SIZE`` times per pass,
        which is the number of messages ``ClientProtocol`` batches before it actually
        starts a transfer.
        """
        while not self._ready_to_exit():
            if not self.tls_path:
                self.add_to_sending_queue(data)
            else:
                for _ in range(MAX_SENDING_QUEUE_SIZE):
                    self.add_to_sending_queue(data)
            time.sleep(2)

    def check_client_alive(self):
        """Return True while the factory reports at least one open connection.

        Returns False when ``start`` has not been called yet (no factory exists),
        instead of failing with ``AttributeError``.
        """
        return self.factory is not None and self.factory.num_connections > 0

    def stop(self):
        """Close the connection through the protocol's timeout handler."""
        self.tcp_manager.connection_timeout()

    def send_stop_signal(self):
        """Send the STOP marker repeatedly, then wait until the connection is gone."""
        self.send_after_queue_empty(self.ACK_STOP)
        # NOTE(review): send_after_queue_empty only returns once _ready_to_exit() is
        # true, so this loop body looks unreachable in practice -- kept as written.
        while not self._ready_to_exit():
            if not self.check_client_alive():
                break
            time.sleep(1)

    def add_to_sending_queue(self, data: Union[bytes, TCPDataItem], rank: int = 0, step: int = 0):
        """Put ``data`` on the send queue, wrapping raw payloads in a ``TCPDataItem``.

        Nothing is queued once the client is ready to exit. A raw payload receives the
        next sequence number; an existing ``TCPDataItem`` (a resend) is queued as is and
        ``rank`` / ``step`` are ignored. If the queue stays full for
        ``QUEUE_PENDING_TIME`` seconds the item is dropped and an error is logged.
        """
        if self._ready_to_exit():
            return

        send_data = data
        if not isinstance(data, TCPDataItem):
            send_data = TCPDataItem(data=data,
                                    sequence_number=self.sequence_number,
                                    rank=rank,
                                    step=step)
            self.sequence_number += 1
        try:
            self.send_queue.put(send_data, block=True, timeout=self.QUEUE_PENDING_TIME)
        except Exception as e:
            logger.error(f"send_queue put send_data timeout, rank: {send_data.rank}, step: {send_data.step},"
                         f"sequence_number: {send_data.sequence_number}, send_queue size: {self.send_queue.qsize()},"
                         f"{str(e)}")

    def _send_data(self, data: TCPDataItem):
        """Hand one item to the protocol object for framing and transmission."""
        self.tcp_manager.send_wrapped_data(data.raw_data,
                                           sequence_number=data.sequence_number,
                                           rank=data.rank,
                                           step=data.step
                                           )

    def _sending_queue_data(self):
        """Worker loop: move items from ``send_queue`` to the server.

        At most ``MAX_SENDING_QUEUE_SIZE`` items may be awaiting acknowledgement at a
        time; while that limit is reached the loop sleeps instead of sending.
        """
        while True:
            if not self.tcp_manager.is_connected:
                # Wait for the connection without spinning, and stop waiting when the
                # client has been told to exit (e.g. the connection attempt failed).
                if self._ready_to_exit():
                    break
                time.sleep(0.1)
                continue

            while self.send_queue.qsize() > 0:
                if self._ready_to_exit():
                    break
                if len(self.resend_dict) < MAX_SENDING_QUEUE_SIZE:
                    data_obj = self.send_queue.get()
                    resend_key = self._resend_key(data_obj.sequence_number, data_obj.rank, data_obj.step)
                    logger.debug(f"get {resend_key} from send_queue, and send to server.")
                    self._send_data(data_obj)
                    if resend_key not in self.resend_dict:
                        # Send data for the first time
                        self.resend_dict[resend_key] = data_obj
                else:
                    time.sleep(0.1)

            if self._ready_to_exit():
                logger.debug("Successfully close sending process.")
                break
            time.sleep(0.1)

    def _destroy_queue_data(self):
        """Worker loop: consume acknowledgements and retire or resend tracked items."""
        while True:
            if self._ready_to_exit():
                break

            while len(self.resend_dict) > 0 and self.tcp_manager.ack_queue.qsize() > 0:
                ack_info, seq_number, rank, step = self.tcp_manager.ack_queue.get()
                obj_key = self._resend_key(seq_number, rank, step)
                current_item = self.resend_dict.get(obj_key)

                if current_item is None:
                    # Acknowledgement for something that is not (or no longer) tracked.
                    continue

                if ack_info == self.ACK_SUCCESS:
                    self.resend_dict.pop(obj_key)
                elif ack_info == self.ACK_BUSY:
                    logger.debug("RECV BUSY ACK")
                    if current_item.busy_time > 5:
                        self._resend_data(current_item)
                    else:
                        current_item.busy_time += 1
                elif ack_info == self.ACK_ERROR:
                    logger.debug("RECV ERROR ACK")
                    self._resend_data(current_item)
                elif ack_info == self.ACK_STOP_CONFIRM:
                    logger.debug("RECV STOP ACK")
                    self.factory.num_connections -= 1

                    break

            time.sleep(0.1)

    def _resend_data(self, data: TCPDataItem):
        """Re-queue ``data`` unless it has already been retried ``RESEND_RETRY_TIMES`` times."""
        if data.retry_times < self.RESEND_RETRY_TIMES:
            data.retry_times += 1
            logger.debug(f"Resend data seq number: {data.sequence_number}")
            self.add_to_sending_queue(data)
        else:
            # resend_dict is keyed by "<sequence>_<rank>_<step>", not by the bare
            # sequence number; popping the integer raised KeyError in the ACK thread
            # and left the abandoned item in the dict.
            self.resend_dict.pop(self._resend_key(data.sequence_number, data.rank, data.step), None)
            logger.debug(f"SKIP send sequence number {data.sequence_number} after retry {data.retry_times} times!")

    def _pending_data(self, data: TCPDataItem):
        """Sleep in proportion to the payload size, or drop the item after too long."""
        if data.pending_time >= self.RESEND_PENDING_TIME:
            # Same string key as everywhere else in resend_dict (see _resend_data).
            self.resend_dict.pop(self._resend_key(data.sequence_number, data.rank, data.step), None)
            logger.debug(f"SKIP send sequence number {data.sequence_number} after pending {data.pending_time} times!")
            return

        # Wait one second per 50 MiB of payload, and at least one second.
        pending_time = max(1, len(data.raw_data) // (2 ** 20 * 50))
        data.pending_time += pending_time
        time.sleep(pending_time)

    def _ready_to_exit(self):
        """Return True once either this client or its protocol object signalled exit."""
        return self.signal_exit or self.tcp_manager.signal_exit
254
+
255
+
256
class ClientProtocol(protocol.Protocol):
    """Twisted protocol that frames outgoing payloads and parses acknowledgements.

    Outgoing messages consist of four 8-byte integers (payload length, sequence
    number, rank, step), an optional MD5 hex digest and the payload. Incoming data
    is read as fixed-size acknowledgement frames of ``ACK_FRAME_SIZE`` bytes: a
    5-byte tag followed by three 8-byte fields (sequence number, rank, step). Parsed
    frames are placed on ``ack_queue`` as ``(ack, seq_number, rank, step)`` tuples.
    """

    TIMEOUT = 60 * 10
    ACK_FRAME_SIZE = 29  # 5 + 8 * 3

    def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False, tls=None):
        """Initialise buffers and state.

        Args:
            ack_queue_size: maximum number of parsed acknowledgements kept queued.
            chunk_size: value assigned to ``FileSender.CHUNK_SIZE``.
            check_sum: when true, an MD5 hex digest is added to every message.
            tls: truthy to batch outgoing messages before transfer (TLS mode).
        """
        self.buffer = io.BytesIO()
        self.is_connected = False
        self.check_sum = check_sum
        self.tell = 0  # offset in ``buffer`` of the first unread byte
        self.ack_queue = Queue(maxsize=ack_queue_size)
        self.file_sender = FileSender()
        self.file_sender.CHUNK_SIZE = chunk_size
        self.signal_exit = False
        self.defer = None
        self.kill_process = False
        self.ack = None

        self.timeout_call = None

        self.tls = tls
        self.send_buffer = b""
        self.buffer_cnt = 0

    def dataReceived(self, data):
        """Append ``data`` to the receive buffer and parse every complete ACK frame."""
        if self.timeout_call is not None and self.timeout_call.active():
            self.timeout_call.reset(self.TIMEOUT)

        self.buffer.seek(0, 2)
        self.buffer.write(data)
        self.buffer.seek(self.tell)
        # Only bytes after ``tell`` are unread. Comparing the total buffer length
        # (as before) allowed a short read -- and a struct.error -- whenever ``tell``
        # had advanced without the buffer being trimmed (ack_queue full).
        while len(self.buffer.getvalue()) - self.tell >= self.ACK_FRAME_SIZE:
            ack = self.buffer.read(5)
            self.ack = ack
            seq_number = struct.unpack(unpack_mode, self.buffer.read(8))[0]
            rank = struct.unpack(unpack_mode, self.buffer.read(8))[0]
            step = struct.unpack(unpack_mode, self.buffer.read(8))[0]
            logger.debug(f"receive 流水号: {seq_number}; RANK: {rank}; STEP: {step}; ACK: {ack}")
            if ack == b"KILL_":
                self.kill_process = True
                logger.debug(f"接收到KILL信号, PID {os.getpid()}")
            if ack == b"OVER_":
                self.factory.num_connections -= 1
            self.tell += self.ACK_FRAME_SIZE
            if not self.ack_queue.full():
                self.ack_queue.put((ack, seq_number, rank, step))
                # Drop everything consumed so far and restart at offset 0.
                self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:])
                self.tell = 0
            else:
                # The frame is skipped (not queued) while the queue is full; the
                # consumed bytes are trimmed the next time a frame is queued.
                time.sleep(0.1)

    def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0):
        """Frame ``data`` with its header and start sending it.

        Blocks (polling every 10 ms) until the previous transfer, if any, has
        finished before starting the new one.
        """
        length = len(data)
        md5_hash = hashlib.md5(data).hexdigest() if self.check_sum else ""
        message = length.to_bytes(8, byteorder=bytes_order) + \
                  sequence_number.to_bytes(8, byteorder=bytes_order) + \
                  rank.to_bytes(8, byteorder=bytes_order) + \
                  step.to_bytes(8, byteorder=bytes_order) + \
                  md5_hash.encode() + \
                  data
        logger.debug(f"send 流水号: {sequence_number}; RANK: {rank}; STEP: {step}; LENGTH: {length}")

        # Wait for the previous transfer's deferred to fire before starting another.
        while self.defer is not None and not self.defer.called:
            time.sleep(0.01)
        self.defer = self.send_large_data(message)

    def send_large_data(self, data):
        """Start a file-style transfer of ``data`` and return its deferred.

        In TLS mode messages are accumulated and only transferred once
        ``MAX_SENDING_QUEUE_SIZE`` of them have been collected; until then ``None``
        is returned.
        """
        if self.tls:
            self.send_buffer += data
            self.buffer_cnt += 1
            if self.buffer_cnt >= MAX_SENDING_QUEUE_SIZE:
                d = self.file_sender.beginFileTransfer(io.BytesIO(self.send_buffer), self.transport)
                self.send_buffer = b""
                self.buffer_cnt = 0
            else:
                d = None
        else:
            d = self.file_sender.beginFileTransfer(io.BytesIO(data), self.transport)
        return d

    def connection_timeout(self):
        """Close the connection, unless the factory already reports none open."""
        if self.factory.num_connections <= 0:
            return

        self.factory.num_connections -= 1
        logger.debug(f"超时退出{self.transport.addr}, PID {os.getpid()}")
        self.transport.loseConnection()

    def connectionMade(self):
        """Arm the inactivity timer and record the new connection on the factory."""
        self.timeout_call = reactor.callLater(self.TIMEOUT, self.connection_timeout)
        self.is_connected = True
        self.factory.num_connections += 1
        logger.info("successfully connect server")

    def connectionLost(self, reason):
        """Signal exit to the worker threads and update the factory's counter."""
        self.signal_exit = True
        self.factory.num_connections -= 1
        logger.info(f"Lost connection with server, reason is : {reason}")
358
+
359
+
360
class MessageClientFactory(protocol.ClientFactory):
    """Client factory that counts open connections and stops the reactor on failure or loss."""

    def __init__(self):
        # Adjusted by the protocol object as connections come and go.
        self.num_connections = 0

    def clientConnectionFailed(self, connector, reason):
        """Log why the connection attempt failed, then stop the reactor."""
        error_message = reason.getErrorMessage()
        logger.info(f"Fail to connection with server: {error_message}")
        reactor.stop()

    def clientConnectionLost(self, connector, reason):
        """Log why the established connection was lost, then stop the reactor."""
        error_message = reason.getErrorMessage()
        logger.info(f"Client lost connection with server: {error_message}")
        reactor.stop()