mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,7 +1,70 @@
1
- hf_32_standard_api = ["conv1d", "conv2d"]
2
-
3
-
4
- class Backward_Message:
5
- MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
6
- UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, skip backward."
7
- NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
1
+ import os
2
+ import re
3
+
4
+ from msprobe.core.common.const import FileCheckConst
5
+ from msprobe.core.common.file_utils import FileChecker
6
+ from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
7
+ from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
8
+ from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
9
+ from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
10
+ from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
11
+
12
# Lower-case API names kept under the "hf_32 standard" label.
# NOTE(review): presumably the APIs judged with the HF32 precision standard -- confirm against the callers.
hf_32_standard_api = ["conv1d", "conv2d"]


class Backward_Message:
    """Fixed messages describing why the backward pass of an API is skipped or unsupported."""

    MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
    UNSUPPORT_BACKWARD_MESSAGE = (
        "function with out=... arguments don't support automatic differentiation, "
        "skip backward."
    )
    NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
19
+
20
+
21
class UtDataInfo:
    """Plain value holder: every constructor argument is stored on a same-named attribute.

    Args:
        bench_grad: gradient obtained on the benchmark side.
        device_grad: gradient obtained on the device side.
        device_output: forward output obtained on the device side.
        bench_output: forward output obtained on the benchmark side.
        grad_in: gradient input. NOTE(review): presumably the tensor fed to backward -- confirm.
        in_fwd_data_list: forward input data list.
        backward_message: text describing the backward status.
        rank: rank the data belongs to; defaults to 0.
    """

    def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list,
                 backward_message, rank=0):
        # forward side
        self.in_fwd_data_list = in_fwd_data_list
        self.bench_output = bench_output
        self.device_output = device_output
        # backward side
        self.grad_in = grad_in
        self.bench_grad = bench_grad
        self.device_grad = device_grad
        self.backward_message = backward_message
        # bookkeeping
        self.rank = rank
32
+
33
+
34
def get_validated_result_csv_path(result_csv_path, mode):
    """Check an existing csv path and return the checked path.

    Args:
        result_csv_path: path of the csv file to check.
        mode: either ``'result'`` or ``'detail'``.

    Returns:
        The value returned by ``FileChecker.common_check()`` for the path.

    Raises:
        ValueError: if ``mode`` is not one of the two accepted values, or if
            ``mode == 'result'`` and the file name is not of the form
            ``accuracy_checking_result_<14 digits>.csv``.
        Whatever ``FileChecker.common_check()`` raises for an unacceptable path.
    """
    if mode not in ('result', 'detail'):
        raise ValueError("The csv mode must be result or detail")

    checker = FileChecker(
        result_csv_path,
        FileCheckConst.FILE,
        ability=FileCheckConst.READ_WRITE_ABLE,
        file_type=FileCheckConst.CSV_SUFFIX,
    )
    checked_path = checker.common_check()

    # Only the result csv carries a file-name constraint.
    if mode != 'result':
        return checked_path

    file_name = os.path.basename(checked_path)
    if re.match(r"^accuracy_checking_result_\d{14}\.csv$", file_name) is None:
        raise ValueError("When continue run ut, please do not modify the result csv name.")
    return checked_path
46
+
47
+
48
def get_validated_details_csv_path(validated_result_csv_path):
    """Derive the details csv path that sits next to a result csv and check it.

    The details file name is the result file name with every ``'result'``
    substring replaced by ``'details'``; the directory is kept.

    Args:
        validated_result_csv_path: path of the result csv.

    Returns:
        The value returned by ``FileChecker.common_check()`` for the details path.
    """
    parent_dir, result_name = os.path.split(validated_result_csv_path)
    details_path = os.path.join(parent_dir, result_name.replace('result', 'details'))
    checker = FileChecker(
        details_path,
        FileCheckConst.FILE,
        ability=FileCheckConst.READ_WRITE_ABLE,
        file_type=FileCheckConst.CSV_SUFFIX,
    )
    return checker.common_check()
56
+
57
+
58
def exec_api(api_type, api_name, device, args, kwargs):
    """Build the hook template that matches ``api_type`` and run its ``forward``.

    Args:
        api_type: one of ``"Functional"``, ``"Tensor"``, ``"Torch"``, ``"Aten"``, ``"NPU"``.
        api_name: name of the API, forwarded to the template constructor.
        device: only forwarded to ``NpuOPTemplate``; ignored for the other types.
        args: positional arguments for the API call.
        kwargs: keyword arguments for the API call.

    Returns:
        Whatever the template's ``forward(*args, **kwargs)`` returns.

    Raises:
        ValueError: if ``api_type`` is not one of the supported values. Previously
            this case fell through every ``if`` and failed with an opaque
            ``UnboundLocalError`` on ``torch_api``.
    """
    # The branches are mutually exclusive, so use a single if/elif chain.
    if api_type == "Functional":
        torch_api = FunctionalOPTemplate(api_name, str, False)
    elif api_type == "Tensor":
        torch_api = TensorOPTemplate(api_name, str, False)
    elif api_type == "Torch":
        torch_api = TorchOPTemplate(api_name, str, False)
    elif api_type == "Aten":
        torch_api = AtenOPTemplate(api_name, None, False)
    elif api_type == "NPU":
        torch_api = NpuOPTemplate(api_name, None, False, device)
    else:
        raise ValueError(f"Unsupported api_type: {api_type!r}")
    return torch_api.forward(*args, **kwargs)
@@ -1,5 +1,8 @@
1
- {
2
- "topk": {
3
- "grad_index": 0
4
- }
1
+ {
2
+ "topk": {
3
+ "grad_index": 0
4
+ },
5
+ "npu_fusion_attention": {
6
+ "grad_index": 0
7
+ }
5
8
  }
@@ -0,0 +1,197 @@
1
+ import glob
2
+ import os.path
3
+ import time
4
+ import re
5
+ from multiprocessing import Queue
6
+ from typing import Optional, Union, Dict, Any
7
+ from dataclasses import dataclass
8
+
9
+ import torch
10
+
11
+ from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
12
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
13
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
14
+ from msprobe.pytorch.common.utils import logger
15
+ from msprobe.core.common.file_utils import remove_path
16
+ from msprobe.pytorch.common.utils import save_api_data, load_api_data, save_pt, load_pt
17
+
18
+ BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]]
19
+
20
+
21
@dataclass
class ATTLConfig:
    """Settings consumed by ``ATTL``.

    ``nfs_path`` and ``tls_path`` default to ``None`` and are therefore
    annotated as ``Optional[str]``.
    """

    # True selects the TCP server role in ATTL.__init__ (when nfs_path is unset).
    is_benchmark_device: bool
    # IPv4 address and port; validated by ATTL.check_attl_config when nfs_path is unset.
    connect_ip: str
    connect_port: int
    # storage_config
    # Existing directory used for file based exchange; when set, no socket is created.
    nfs_path: Optional[str] = None
    # Forwarded unchanged to TCPServer / TCPClient.
    tls_path: Optional[str] = None
    # Forwarded unchanged to TCPServer / TCPClient.
    check_sum: bool = True
    # NOTE(review): not read by the ATTL code visible in this file -- confirm where it is used.
    queue_size: int = 50
31
+
32
+
33
class ATTL:
    """Exchange API data between the two sides of a session.

    The working mode is chosen in ``__init__`` from ``session_config``:

    * ``nfs_path`` set: no socket is created; ``upload`` / ``download`` use
      files in that directory.
    * ``is_benchmark_device`` true: a ``TCPServer`` is started and handed
      ``data_queue``, which ``recv`` consumes.
    * otherwise, when ``need_dump`` is true: a ``TCPClient`` is started and
      ``send`` queues data on it.
    """

    def __init__(self, session_id: str, session_config: ATTLConfig, need_dump=True) -> None:
        """Validate the configuration and start the socket manager the mode requires."""
        self.session_id = session_id
        self.session_config = session_config
        self.logger = logger
        self.socket_manager = None
        # NOTE(review): the bound is hard-coded; session_config.queue_size is not consulted here.
        self.data_queue = Queue(maxsize=50)
        self.dequeue_list = []
        self.message_end = False
        self.kill_progress = False
        self.check_attl_config()
        if self.session_config.nfs_path:
            self.nfs_path = self.session_config.nfs_path
        elif self.session_config.is_benchmark_device:
            self.socket_manager = TCPServer(self.session_config.connect_port,
                                            self.data_queue,
                                            self.session_config.check_sum,
                                            self.session_config.tls_path)
            self.socket_manager.start()
        elif need_dump:
            self.socket_manager = TCPClient(self.session_config.connect_ip,
                                            self.session_config.connect_port,
                                            self.session_config.check_sum,
                                            self.session_config.tls_path)
            self.socket_manager.start()

    def check_attl_config(self):
        """Validate ``session_config``.

        With ``nfs_path`` set only the existence of that path is checked;
        otherwise ``connect_ip`` must be a dotted IPv4 address and
        ``connect_port`` must lie in 1..65535.

        Raises:
            Exception: if any of the checks fails.
        """
        if self.session_config.nfs_path:
            if os.path.exists(self.session_config.nfs_path):
                return
            raise Exception(f"nfs path {self.session_config.nfs_path} doesn't exists.")
        # Raw string: "\d" and "\." are regex escapes, not Python string escapes.
        ipv4_pattern = r"([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
        # fullmatch, so that an address followed by a newline is rejected as well.
        if not re.fullmatch(ipv4_pattern, self.session_config.connect_ip):
            raise Exception(f"host {self.session_config.connect_ip} is invalid.")
        if not (0 < self.session_config.connect_port <= 65535):
            raise Exception(f"port {self.session_config.connect_port} is invalid.")

    def stop_serve(self):
        """Stop the socket manager when it is a ``TCPServer``; otherwise do nothing."""
        if isinstance(self.socket_manager, TCPServer):
            self.socket_manager.stop()

    def send(self, buffer: BufferType) -> None:
        """Serialize ``buffer`` and queue it on the socket manager (client side).

        An ``ApiData`` buffer is first moved to the CPU. A ``'device'`` entry in
        the buffer's ``kwargs`` is dropped before serialization. When
        ``save_api_data`` fails the buffer is skipped and the failure is logged.
        """
        if isinstance(buffer, ApiData):
            buffer = move2target_device(buffer, torch.device('cpu'))

        # BufferType also allows dict / str, which carry no ``kwargs`` attribute.
        kwargs = getattr(buffer, "kwargs", None)
        if kwargs and 'device' in kwargs:
            kwargs.pop('device')
        rank = buffer.rank if hasattr(buffer, "rank") and buffer.rank is not None else 0
        step = buffer.step if hasattr(buffer, "step") else 0
        try:
            io_buff = save_api_data(buffer)
        except Exception as e:
            self.logger.info(f"{getattr(buffer, 'name', buffer)} can not be saved, skip: {e}")
            return
        data = io_buff.getvalue()
        self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)

    def recv(self, timeout_ms=0) -> Optional[BufferType]:
        """Take the next item from ``data_queue`` and decode it.

        With ``timeout_ms > 0`` the call sleeps that long once, then gives up if
        the queue is empty. With ``timeout_ms <= 0`` it polls every 0.1 s until
        data arrives or the end of the stream has been signalled.

        Returns:
            * ``"STOP_"`` for a ``b"STOP_"`` or ``b"KILL_"`` marker (the latter
              also sets ``message_end``);
            * ``"KILL_"`` once ``message_end`` is set and the queue is drained;
            * ``None`` on timeout, or when the payload could not be decoded;
            * otherwise the object returned by ``load_api_data``.
        """
        buffer = None
        while buffer is None:
            if timeout_ms > 0:
                time.sleep(timeout_ms / 1000.0)
            if not self.data_queue.empty():
                buffer = self.data_queue.get()
                break
            if timeout_ms > 0:  # timeout is the only case we give up and return None
                break
            if self.message_end and self.data_queue.empty():
                buffer = b"KILL_CONFIRM"
                self.kill_progress = True
                break
            time.sleep(0.1)  # waiting outside the lock before next attempt
        if buffer is None:
            # this is a result of a timeout
            self.logger.info("RECEIVE API DATA TIMED OUT")
        else:
            if buffer == b"STOP_":
                return "STOP_"
            if buffer == b"KILL_":
                self.message_end = True
                return "STOP_"
            if buffer == b"KILL_CONFIRM":
                self.kill_progress = True
                return "KILL_"
            try:
                buffer = load_api_data(buffer)
            except Exception as e:
                self.logger.warning("there is something error. please check it. %s", e)
            # Still bytes: decoding failed, so nothing usable can be returned.
            if isinstance(buffer, bytes):
                return None
            if isinstance(buffer, str):
                return buffer

        return buffer

    def upload(self, buffer: BufferType):
        """Save ``buffer`` into the configured nfs directory.

        An ``ApiData`` buffer is moved to the CPU and stored as ``<name>.pt``;
        any other buffer is stored under ``<buffer>_<epoch seconds>``, which
        requires it to be a string. Failures of ``save_pt`` are logged, not raised.
        """
        if isinstance(buffer, ApiData):
            buffer = move2target_device(buffer, torch.device('cpu'))
            file_path = os.path.join(self.session_config.nfs_path, buffer.name + ".pt")
        else:
            file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")

        try:
            save_pt(buffer, file_path)
        except Exception as e:
            self.logger.warning("there is something error in save_pt. please check it. %s", e)

    def download(self):
        """Load and remove one file from the nfs directory.

        Files are looked up with the patterns ``start*``, ``*.pt`` and ``end*``
        in that priority; the first match of the first non-empty pattern is
        used. The file is removed even when loading it failed.

        Returns:
            The object returned by ``load_pt``, or ``None`` when no file matched
            or loading failed.
        """
        buffer = None
        cur_file = None
        for file_type in ("start*", "*.pt", "end*"):
            pattern = os.path.join(self.nfs_path, file_type)
            files = glob.glob(pattern)
            if len(files) > 0:
                cur_file = files[0]
                break

        if cur_file is not None:
            try:
                buffer = load_pt(cur_file)
            except Exception as e:
                self.logger.warning("there is something error. please check it. %s", e)
            remove_path(cur_file)
        return buffer
163
+
164
+
165
def move2device_exec(obj, device):
    """Return a copy of ``obj`` in which every tensor is detached and placed on ``device``.

    Lists, tuples and dicts are rebuilt recursively (a list stays a list, any
    other sequence handled here becomes a plain tuple). A ``torch.device``
    value is replaced by ``torch.device(device)``. Anything else is returned
    unchanged.
    """
    if isinstance(obj, (tuple, list)):
        moved = [move2device_exec(item, device) for item in obj]
        if isinstance(obj, list):
            return moved
        return tuple(moved)
    if isinstance(obj, dict):
        return {name: move2device_exec(item, device) for name, item in obj.items()}
    if isinstance(obj, torch.Tensor):
        detached = obj.detach()
        # NOTE(review): this compares the device *type* string with ``device`` as passed in;
        # when ``device`` is a torch.device object the two are presumably never equal, so
        # ``.to`` is always called in that case -- confirm this is intended.
        if detached.device.type != device:
            detached = detached.to(device)
        return detached
    if "return_types" in str(type(obj)):
        return move2device_exec(tuple(obj), device)
    if isinstance(obj, torch._C.device):
        return torch.device(device)
    return obj
182
+
183
+
184
def move2target_device(buffer: ApiData, target_device):
    """Build a new ``ApiData`` whose args and kwargs live on ``target_device``.

    The result field is only replaced by its moved copy when the target is the
    CPU (given either as ``torch.device('cpu')`` or as the string ``"cpu"``);
    for any other target the original ``buffer.result`` is kept.
    """
    moved_args = move2device_exec(buffer.args, target_device)
    moved_kwargs = move2device_exec(buffer.kwargs, target_device)
    moved_result = move2device_exec(buffer.result, target_device)

    to_cpu = target_device == torch.device('cpu') or target_device == "cpu"
    result = moved_result if to_cpu else buffer.result
    return ApiData(buffer.name, tuple(moved_args), moved_kwargs, result, buffer.step, buffer.rank)
@@ -0,0 +1,325 @@
1
+ import hashlib
2
+ import io
3
+ import struct
4
+ import time
5
+ import os
6
+ import signal
7
+ import sys
8
+ from queue import Queue
9
+ from threading import Thread
10
+ from typing import Union
11
+
12
+ from twisted.internet import reactor, protocol, endpoints
13
+ from twisted.protocols.basic import FileSender
14
+
15
+ from msprobe.pytorch.common.utils import logger
16
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.ssl_config import cipher_list
17
+
18
+
19
class TCPDataItem:
    """One payload queued for sending, plus its delivery bookkeeping.

    Attributes set at construction:
        raw_data: the payload passed in as ``data``.
        sequence_number, rank, step: identifiers stored alongside the payload.
        retry_times, pending_time, busy_time: counters, all starting at 0.
    """

    def __init__(self, data,
                 sequence_number: int,
                 rank: int = 0,
                 step: int = 0):
        # Payload and its identifiers.
        self.raw_data, self.sequence_number = data, sequence_number
        self.rank, self.step = rank, step
        # Delivery counters, updated by the sender while the item is in flight.
        self.retry_times = self.pending_time = self.busy_time = 0
31
+
32
+
33
class TCPClient:
    """Queue payloads and send them to a server through a ``ClientProtocol``.

    Two daemon threads are started by the constructor:
      * the sending thread drains ``send_queue`` and records every item it
        sends in ``resend_dict`` under a ``"<seq>_<rank>_<step>"`` key;
      * the destroy thread consumes ACK records from the protocol's
        ``ack_queue`` and removes, re-queues or drops the matching item.

    ``start()`` must be called to create the connection and run the reactor.
    """
    MAX_SENDING_QUEUE_SIZE = 20
    ACK_SUCCESS = b"OK___"
    ACK_ERROR = b"ERROR"
    ACK_BUSY = b"BUSY_"
    ACK_STOP = b"STOP_"
    ACK_STOP_CONFIRM = b"OVER_"
    ACK_KILL_PROCESS = b"KILL_"

    QUEUE_PENDING_TIME = 600  # if the queue stays blocked for 10 minutes, terminate the sending process
    RESEND_RETRY_TIMES = 2  # maximum number of retransmissions
    RESEND_TIMER_TIME = 5  # timeout timer for receiving an ACK
    RESEND_PENDING_TIME = 60  # give up the data once it has been pending for more than 1 minute

    def __init__(self, host="localhost", port=8000, check_sum=False, tls_path=None):
        """Create the queues and protocol object and start the worker threads.

        Args:
            host: server host to connect to.
            port: server port to connect to.
            check_sum: forwarded to ``ClientProtocol``; enables the MD5 field.
            tls_path: directory holding ``client.key`` / ``client.crt``; when
                falsy a plain TCP endpoint is used by ``start()``.
        """
        self.send_queue = Queue(self.MAX_SENDING_QUEUE_SIZE)
        self.resend_dict = dict()
        self.host = host
        self.port = port
        self.tls_path = tls_path
        self.factory = None
        self.sequence_number = 0
        self.signal_exit = False
        self.tcp_manager = ClientProtocol(ack_queue_size=100,
                                          chunk_size=655360,
                                          check_sum=check_sum)
        # daemon=True replaces the deprecated Thread.setDaemon(True) and
        # matches how start() creates the reactor thread.
        self.send_thread = Thread(target=self._sending_queue_data, daemon=True)
        self.send_thread.start()
        self.destroy_thread = Thread(target=self._destroy_queue_data, daemon=True)
        self.destroy_thread.start()

    @staticmethod
    def run_reactor():
        """Run the Twisted reactor without installing signal handlers."""
        reactor.run(installSignalHandlers=False)

    @staticmethod
    def _make_resend_key(sequence_number, rank, step):
        """Build the ``resend_dict`` key identifying one sent item."""
        return str(sequence_number) + "_" + str(rank) + "_" + str(step)

    def start(self):
        """Connect to ``host:port`` (TLS when ``tls_path`` is set) and start the reactor thread."""
        def conn_callback(cur_protocol):
            # The connection is accepted only when the peer address reported
            # by the transport equals the configured host string.
            if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host:
                logger.debug(f"Process: {os.getpid()} connects to server successfully.")
            else:
                logger.warning(f"Process: {os.getpid()} fails to connect to server. ")
                raise ConnectionError(f"Failed to connect to {self.host}.")

        def conn_err_callback(failure):
            self.signal_exit = True
            time.sleep(1)
            reactor.stop()
            logger.error(f"Failed to connected {self.host} {self.port}. Reason is {failure.getErrorMessage()}")
            # NOTE(review): the first SIGKILL targets this very process, so the
            # second call is presumably never reached - confirm the intent.
            os.kill(os.getpid(), signal.SIGKILL)
            os.kill(os.getppid(), signal.SIGKILL)

        def cur_protocol():
            # Every connection shares the single protocol instance.
            return self.tcp_manager

        self.factory = MessageClientFactory()
        self.factory.protocol = cur_protocol
        if self.tls_path:
            from OpenSSL import SSL
            from twisted.internet import ssl
            client_key = os.path.join(self.tls_path, "client.key")
            client_crt = os.path.join(self.tls_path, "client.crt")
            client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt, SSL.TLSv1_2_METHOD)
            client_context_ = client_context_factory.getContext()
            client_context_.set_cipher_list(cipher_list)
            client_context_.set_options(SSL.OP_NO_RENEGOTIATION)
            endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory)
        else:
            endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port)
        d = endpoint.connect(self.factory)
        d.addCallback(conn_callback)
        d.addErrback(conn_err_callback)

        reactor_thread = Thread(target=self.run_reactor, daemon=True)
        reactor_thread.start()

    def send_after_queue_empty(self, data):
        """Queue ``data`` every 2 seconds until an exit has been signalled."""
        while not self._ready_to_exit():
            self.add_to_sending_queue(data)
            time.sleep(2)

    def check_client_alive(self):
        """Return True while the factory counts at least one connection."""
        return self.factory.num_connections > 0

    def stop(self):
        """Close the connection through the protocol's timeout handler."""
        self.tcp_manager.connection_timeout()

    def send_stop_signal(self):
        """Send the STOP message, then block until the KILL ACK has been received."""
        self.send_after_queue_empty(self.ACK_STOP)
        while not self._ready_to_exit():
            if not self.check_client_alive():
                break
            time.sleep(1)
        while not self.tcp_manager.kill_process:
            time.sleep(1)

    def add_to_sending_queue(self, data: Union[bytes, TCPDataItem], rank: int = 0, step: int = 0):
        """Put ``data`` on the sending queue; a no-op once an exit was signalled.

        Raw payloads are wrapped in a ``TCPDataItem`` carrying the next
        sequence number; an existing ``TCPDataItem`` (a resend) is queued
        as-is. A failed ``put`` (blocked longer than ``QUEUE_PENDING_TIME``)
        is logged and the item is dropped.
        """
        if self._ready_to_exit():
            return

        send_data = data
        if not isinstance(data, TCPDataItem):
            send_data = TCPDataItem(data=data,
                                    sequence_number=self.sequence_number,
                                    rank=rank,
                                    step=step)
            self.sequence_number += 1
        try:
            self.send_queue.put(send_data, block=True, timeout=self.QUEUE_PENDING_TIME)
        except Exception as e:
            logger.error(f"send_queue put send_data timeout, rank: {send_data.rank}, step: {send_data.step},"
                         f"sequence_number: {send_data.sequence_number}, {str(e)}")

    def _send_data(self, data: TCPDataItem):
        """Hand one item to the protocol for framing and transmission."""
        self.tcp_manager.send_wrapped_data(data.raw_data,
                                           sequence_number=data.sequence_number,
                                           rank=data.rank,
                                           step=data.step
                                           )

    def _sending_queue_data(self):
        """Sending-thread loop: drain ``send_queue`` while connected."""
        while True:
            if not self.tcp_manager.is_connected:
                # Wait for the connection without spinning at 100% CPU, and
                # stop if the connection attempt already signalled an exit.
                if self._ready_to_exit():
                    break
                time.sleep(0.1)
                continue

            while self.send_queue.qsize() > 0:
                if self._ready_to_exit():
                    break
                if len(self.resend_dict) < self.MAX_SENDING_QUEUE_SIZE:
                    data_obj = self.send_queue.get()
                    self._send_data(data_obj)
                    resend_key = self._make_resend_key(data_obj.sequence_number, data_obj.rank, data_obj.step)
                    if resend_key not in self.resend_dict:
                        # Send data for the first time
                        self.resend_dict[resend_key] = data_obj
                else:
                    # Too many un-acknowledged items: wait for ACKs to free slots.
                    time.sleep(0.1)

            if self._ready_to_exit():
                logger.debug("Successfully close sending process.")
                break
            time.sleep(0.1)

    def _destroy_queue_data(self):
        """Destroy-thread loop: apply received ACKs to ``resend_dict``."""
        while True:
            if self._ready_to_exit():
                break

            while len(self.resend_dict) > 0 and self.tcp_manager.ack_queue.qsize() > 0:
                ack_info, seq_number, rank, step = self.tcp_manager.ack_queue.get()
                obj_key = self._make_resend_key(seq_number, rank, step)
                current_item = self.resend_dict.get(obj_key)

                if current_item is None:
                    # ACK for an item that is not (or no longer) tracked.
                    continue

                if ack_info == self.ACK_SUCCESS:
                    self.resend_dict.pop(obj_key)
                elif ack_info == self.ACK_BUSY:
                    logger.debug("RECV BUSY ACK")
                    if current_item.busy_time > 5:
                        self._resend_data(current_item)
                    else:
                        current_item.busy_time += 1
                elif ack_info == self.ACK_ERROR:
                    logger.debug("RECV ERROR ACK")
                    self._resend_data(current_item)
                elif ack_info == self.ACK_STOP_CONFIRM:
                    logger.debug("RECV STOP ACK")
                    self.factory.num_connections -= 1

                    break

            time.sleep(0.1)

    def _resend_data(self, data: TCPDataItem):
        """Re-queue ``data`` until ``RESEND_RETRY_TIMES`` is reached, then drop it."""
        if data.retry_times < self.RESEND_RETRY_TIMES:
            data.retry_times += 1
            logger.debug(f"Resend data seq number: {data.sequence_number}")
            self.add_to_sending_queue(data)
        else:
            # resend_dict is keyed by "<seq>_<rank>_<step>", not by the bare
            # sequence number; popping the bare number raised KeyError and
            # left the entry in place.
            self.resend_dict.pop(self._make_resend_key(data.sequence_number, data.rank, data.step), None)
            logger.debug(f"SKIP send sequence number {data.sequence_number} after retry {data.retry_times} times!")

    def _pending_data(self, data: TCPDataItem):
        """Sleep proportionally to the payload size, dropping ``data`` once it
        has accumulated ``RESEND_PENDING_TIME`` seconds of waiting."""
        if data.pending_time >= self.RESEND_PENDING_TIME:
            # Same key format as used when the item was recorded.
            self.resend_dict.pop(self._make_resend_key(data.sequence_number, data.rank, data.step), None)
            logger.debug(f"SKIP send sequence number {data.sequence_number} after pending {data.pending_time} times!")
            return

        # wait one second per 50 MiB of payload, with a minimum of one second
        pending_time = max(1, len(data.raw_data) // (2 ** 20 * 50))
        data.pending_time += pending_time
        time.sleep(pending_time)

    def _ready_to_exit(self):
        """Return True once either this client or its protocol signalled an exit."""
        return self.signal_exit or self.tcp_manager.signal_exit
231
+
232
+
233
class ClientProtocol(protocol.Protocol):
    """Twisted protocol that frames outgoing payloads and parses incoming ACKs.

    Incoming data is a stream of fixed 29-byte ACK records: a 5-byte ACK
    code followed by three 8-byte unsigned network-order integers
    (sequence number, rank, step). Parsed records are pushed to ``ack_queue``.
    """
    TIMEOUT = 60 * 10
    ACK_RECORD_SIZE = 29  # 5 + 8 * 3

    def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False):
        """
        Args:
            ack_queue_size: capacity of the queue holding parsed ACK records.
            chunk_size: chunk size assigned to the ``FileSender``.
            check_sum: when true, an MD5 hex digest is sent before each payload.
        """
        self.buffer = io.BytesIO()
        self.is_connected = False
        self.check_sum = check_sum
        self.tell = 0  # offset in ``buffer`` of the first unparsed byte
        self.ack_queue = Queue(maxsize=ack_queue_size)
        self.file_sender = FileSender()
        self.file_sender.CHUNK_SIZE = chunk_size
        self.signal_exit = False
        self.defer = None
        self.kill_process = False
        # Created in connectionMade(); defined here so the attribute always exists.
        self.timeout_call = None

    def dataReceived(self, data):
        """Append ``data`` to the buffer and parse every complete ACK record."""
        if self.timeout_call is not None and self.timeout_call.active():
            self.timeout_call.reset(self.TIMEOUT)

        self.buffer.seek(0, 2)
        self.buffer.write(data)
        self.buffer.seek(self.tell)
        # Compare against the unread remainder, not the whole buffer: after a
        # dropped ACK (queue full) ``tell`` is non-zero and testing the total
        # length let the loop read past the end, making struct.unpack raise.
        while len(self.buffer.getvalue()) - self.tell >= self.ACK_RECORD_SIZE:
            ack = self.buffer.read(5)
            seq_number = struct.unpack('!Q', self.buffer.read(8))[0]
            rank = struct.unpack('!Q', self.buffer.read(8))[0]
            step = struct.unpack('!Q', self.buffer.read(8))[0]
            if ack == b"KILL_":
                self.kill_process = True
                logger.debug(f"接收到KILL信号, PID {os.getpid()}")
            if ack == b"OVER_":
                self.factory.num_connections -= 1
            self.tell += self.ACK_RECORD_SIZE
            if not self.ack_queue.full():
                self.ack_queue.put((ack, seq_number, rank, step))
                # Drop the consumed bytes so the buffer does not keep growing.
                self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:])
                self.tell = 0
            else:
                # Queue full: this ACK record is skipped after a short pause.
                time.sleep(0.1)

    def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0):
        """Frame ``data`` and send it once the previous transfer has finished.

        Frame layout: four 8-byte big-endian integers (payload length,
        sequence number, rank, step), the MD5 hex digest of the payload
        (empty when ``check_sum`` is off), then the payload itself.
        """
        length = len(data)
        md5_hash = hashlib.md5(data).hexdigest() if self.check_sum else ""
        while True:
            # Only one transfer at a time: wait until the previous Deferred fired.
            if self.defer is None or self.defer.called:
                self.defer = self.send_large_data(
                    length.to_bytes(8, byteorder='big') +
                    sequence_number.to_bytes(8, byteorder='big') +
                    rank.to_bytes(8, byteorder='big') +
                    step.to_bytes(8, byteorder='big') +
                    md5_hash.encode() +
                    data)
                break
            time.sleep(0.01)

    def send_large_data(self, data):
        """Stream ``data`` to the transport and return the transfer's Deferred."""
        d = self.file_sender.beginFileTransfer(io.BytesIO(data), self.transport)
        return d

    def connection_timeout(self):
        """Close the connection unless the factory already counts none."""
        if self.factory.num_connections <= 0:
            return

        self.factory.num_connections -= 1
        logger.debug(f"超时退出{self.transport.addr}, PID {os.getpid()}")
        self.transport.loseConnection()

    def connectionMade(self):
        """Arm the inactivity timer and mark the protocol as connected."""
        self.timeout_call = reactor.callLater(self.TIMEOUT, self.connection_timeout)
        self.is_connected = True
        self.factory.num_connections += 1
        logger.info("successfully connect server")

    def connectionLost(self, reason):
        """Signal the worker threads to exit and update the connection count."""
        self.signal_exit = True
        self.factory.num_connections -= 1
        logger.info(f"Lost connection with server, reason is : {reason}")
313
+
314
+
315
class MessageClientFactory(protocol.ClientFactory):
    """Client factory that counts live connections and stops the reactor
    whenever the connection fails or is lost."""

    def __init__(self):
        # Number of currently counted connections; adjusted by the protocol.
        self.num_connections = 0

    def clientConnectionFailed(self, connector, reason):
        """Log the failure reason, then stop the reactor."""
        failure_text = reason.getErrorMessage()
        logger.info(f"Fail to connection with server: {failure_text}")
        reactor.stop()

    def clientConnectionLost(self, connector, reason):
        """Log the disconnect reason, then stop the reactor."""
        lost_text = reason.getErrorMessage()
        logger.info(f"Client lost connection with server: {lost_text}")
        reactor.stop()