mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
@@ -0,0 +1,202 @@
1
+ import io
2
+ import os.path
3
+ import time
4
+ import re
5
+ from pathlib import Path
6
+ from multiprocessing import Queue
7
+ from typing import Optional, Union, Dict, Any
8
+ from collections import namedtuple
9
+ from dataclasses import dataclass
10
+
11
+ import torch
12
+
13
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
14
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
15
+ from msprobe.pytorch.common.utils import logger
16
+ from msprobe.pytorch.common.utils import save_pt
17
+ from msprobe.core.common.utils import remove_path
18
+
19
+
20
# Record describing one captured API invocation. Every field has a default,
# so ``ApiData()`` is valid and equals ('unknown', None, None, None, 0, 0).
ApiData = namedtuple(
    'ApiData',
    ('name', 'args', 'kwargs', 'result', 'step', 'rank'),
    defaults=('unknown', None, None, None, 0, 0),
)
# Payload kinds that ATTL.send / ATTL.upload are annotated to accept.
BufferType = Union[ApiData, Dict[str, Any], str]
23
+
24
+
25
@dataclass
class ATTLConfig:
    """Settings for one ATTL session.

    Attributes are documented inline; the three connection fields are
    required, the remaining ones have defaults.
    """

    # True on the benchmark side; ATTL starts a TCPServer there, otherwise a TCPClient.
    is_benchmark_device: bool
    # Peer address; ATTL validates it as a dotted IPv4 string when nfs_path is unset.
    connect_ip: str
    # TCP port; ATTL requires 0 < connect_port <= 65535 when nfs_path is unset.
    connect_port: int
    # storage_config
    # Directory used for file-based exchange; when set, ATTL opens no socket.
    nfs_path: Optional[str] = None
    # Directory handed to TCPServer / TCPClient as their tls_path argument.
    tls_path: Optional[str] = None
    # Handed to TCPServer / TCPClient as their check_sum argument.
    check_sum: bool = True
    # Capacity intended for the receive queue.
    queue_size: int = 50
35
+
36
+
37
class ATTL:
    """Transport layer that moves captured API data between two processes.

    Depending on the session configuration the instance works in one of
    three modes:

    * ``nfs_path`` set: data is exchanged through files (``upload`` /
      ``download``); no socket is opened.
    * ``is_benchmark_device`` true: a ``TCPServer`` is started and received
      payloads are read from ``data_queue`` via ``recv``.
    * otherwise, when ``need_dump`` is true: a ``TCPClient`` is started and
      payloads are pushed with ``send``.
    """

    def __init__(self, session_id: str, session_config: ATTLConfig, need_dump=True) -> None:
        """Validate the configuration and start the matching transport.

        Args:
            session_id: Identifier stored on the instance.
            session_config: Connection / storage settings.
            need_dump: When false on a non-benchmark device without
                ``nfs_path``, no client is started and ``socket_manager``
                stays ``None``.

        Raises:
            Exception: If ``check_attl_config`` rejects the configuration.
        """
        self.session_id = session_id
        self.session_config = session_config
        self.logger = logger
        self.socket_manager = None
        # Honour the configured capacity; its default (50) is the value that
        # used to be hard-coded here.
        self.data_queue = Queue(maxsize=self.session_config.queue_size)
        self.dequeue_list = []
        self.message_end = False
        self.kill_progress = False
        self.check_attl_config()
        if self.session_config.nfs_path:
            self.nfs_path = Path(self.session_config.nfs_path)
        elif self.session_config.is_benchmark_device:
            self.socket_manager = TCPServer(self.session_config.connect_port,
                                            self.data_queue,
                                            self.session_config.check_sum,
                                            self.session_config.tls_path)
            self.socket_manager.start()
        elif need_dump:
            self.socket_manager = TCPClient(self.session_config.connect_ip,
                                            self.session_config.connect_port,
                                            self.session_config.check_sum,
                                            self.session_config.tls_path)
            self.socket_manager.start()

    def check_attl_config(self):
        """Validate ``session_config``.

        When ``nfs_path`` is set only its existence is checked and the
        network settings are ignored. Otherwise ``connect_ip`` must be a
        dotted IPv4 address and ``connect_port`` must be in 1..65535.

        Raises:
            Exception: If the NFS path is missing, or the IP or port is invalid.
        """
        nfs_path = self.session_config.nfs_path
        if nfs_path:
            if os.path.exists(nfs_path):
                return
            raise Exception(f"nfs path {nfs_path} doesn't exists.")
        # Raw string avoids invalid-escape warnings; fullmatch also rejects a
        # trailing newline, which "match" + "$" would let through.
        ipv4_pattern = r"([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}"
        if not re.fullmatch(ipv4_pattern, self.session_config.connect_ip):
            raise Exception(f"host {self.session_config.connect_ip} is invalid.")
        if not (0 < self.session_config.connect_port <= 65535):
            raise Exception(f"port {self.session_config.connect_port} is invalid.")

    def stop_serve(self):
        """Stop the socket manager if this instance runs a ``TCPServer``."""
        if isinstance(self.socket_manager, TCPServer):
            self.socket_manager.stop()

    def send(self, buffer: BufferType) -> None:
        """Serialize ``buffer`` with ``torch.save`` and queue it for sending.

        ``ApiData`` payloads are first copied to CPU and a ``'device'``
        entry is removed from their kwargs. Other payload kinds (str / dict)
        are serialized as they are. If serialization fails the payload is
        skipped and an info message is logged.
        """
        if isinstance(buffer, ApiData):
            buffer = move2target_device(buffer, torch.device('cpu'))
            # Only ApiData has kwargs; str / dict payloads must not be probed.
            if buffer.kwargs and 'device' in buffer.kwargs:
                buffer.kwargs.pop('device')

        rank = getattr(buffer, "rank", None)
        if rank is None:
            rank = 0
        step = getattr(buffer, "step", 0)

        io_buff = io.BytesIO()
        try:
            torch.save(buffer, io_buff)
        except Exception as e:
            label = getattr(buffer, "name", type(buffer).__name__)
            self.logger.info(f"{label} can not be saved, skip: {e}")
            return
        self.socket_manager.add_to_sending_queue(io_buff.getvalue(), rank=rank, step=step)

    def recv(self, timeout_ms=0) -> Optional[BufferType]:
        """Fetch the next payload from ``data_queue``.

        With ``timeout_ms > 0`` the call sleeps for that long, checks the
        queue once and gives up if it is empty. With ``timeout_ms == 0`` it
        polls every 0.1 s until data arrives, or until a kill message was
        seen earlier and the queue has drained.

        Returns:
            ``"STOP_"`` for a stop or kill message, ``"KILL_"`` once the
            kill is confirmed, the deserialized payload otherwise, or
            ``None`` on timeout, on a deserialization failure, or when the
            payload deserializes to ``bytes``.
        """
        buffer = None
        while buffer is None:
            if timeout_ms > 0:
                time.sleep(timeout_ms / 1000.0)
            if not self.data_queue.empty():
                buffer = self.data_queue.get()
                break
            if timeout_ms > 0:  # timeout is the only case we give up and return None
                break
            if self.message_end and self.data_queue.empty():
                buffer = b"KILL_CONFIRM"
                self.kill_progress = True
                break
            time.sleep(0.1)  # wait before the next attempt

        if buffer is None:
            # this is a result of a timeout
            self.logger.info("RECEIVE API DATA TIMED OUT")
            return None

        if buffer == b"STOP_":
            return "STOP_"
        if buffer == b"KILL_":
            self.message_end = True
            return "STOP_"
        if buffer == b"KILL_CONFIRM":
            self.kill_progress = True
            return "KILL_"

        stream = io.BytesIO(buffer)
        try:
            # NOTE(review): torch.load unpickles bytes that arrived over the
            # socket; only use this with a trusted peer.
            loaded = torch.load(stream, map_location="cpu")
        except Exception as e:
            self.logger.warning("there is something error. please check it. %s", e)
            # Previously the BytesIO wrapper itself leaked out to the caller.
            return None
        if isinstance(loaded, bytes):
            return None
        return loaded

    def upload(self, buffer: BufferType):
        """Write ``buffer`` into the NFS directory with ``save_pt``.

        ``ApiData`` is copied to CPU and stored as ``<name>.pt``. Any other
        payload is concatenated with a timestamp to form the file name, so
        it has to be a ``str``. Failures inside ``save_pt`` are logged as
        warnings and not raised.
        """
        if isinstance(buffer, ApiData):
            buffer = move2target_device(buffer, torch.device('cpu'))
            file_path = os.path.join(self.session_config.nfs_path, buffer.name + ".pt")
        else:
            file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")

        try:
            save_pt(buffer, file_path)
        except Exception as e:
            self.logger.warning("there is something error in save_pt. please check it. %s", e)

    def download(self):
        """Load and remove one file from the NFS directory.

        Files are searched in the order ``start*``, ``*.pt``, ``end*`` and
        the first match is taken.

        Returns:
            The object loaded by ``torch.load``, or ``None`` if no file
            matched or loading failed. The file is removed in both of the
            latter cases where a file was found.
        """
        cur_file = None
        for file_type in ("start*", "*.pt", "end*"):
            cur_file = next(self.nfs_path.glob(file_type), None)
            if cur_file is not None:
                break

        if cur_file is None:
            return None

        buffer = None
        try:
            buffer = torch.load(cur_file)
        except Exception as e:
            self.logger.warning("there is something error. please check it. %s", e)
        remove_path(cur_file)
        return buffer
168
+
169
+
170
def move2device_exec(obj, device):
    """Return a copy of ``obj`` with every tensor detached and placed on ``device``.

    Lists, tuples and dicts are rebuilt recursively (a tuple subclass comes
    back as a plain ``tuple``). ``torch.device`` values are replaced by
    ``torch.device(device)``. Anything else is returned unchanged.

    Args:
        obj: Arbitrary, possibly nested, value.
        device: Target device, as a ``torch.device`` or a device string.
    """
    if isinstance(obj, (tuple, list)):
        moved = [move2device_exec(item, device) for item in obj]
        return moved if isinstance(obj, list) else tuple(moved)
    if isinstance(obj, dict):
        return {key: move2device_exec(val, device) for key, val in obj.items()}
    if isinstance(obj, torch.Tensor):
        # Tensor.to() hands back the same tensor when it already lives on the
        # target, so no separate device comparison is needed. The previous
        # check compared a str against a possibly-torch.device value.
        return obj.detach().to(device)
    if "return_types" in str(type(obj)):
        return move2device_exec(tuple(obj), device)
    if isinstance(obj, torch.device):
        return torch.device(device)
    return obj
187
+
188
+
189
def move2target_device(buffer: ApiData, target_device):
    """Build a new ``ApiData`` whose args and kwargs live on ``target_device``.

    The result field is moved only when the target is the CPU; for any
    other target the original ``buffer.result`` is carried over untouched.
    ``name``, ``step`` and ``rank`` are copied as they are.

    Args:
        buffer: Source record; it is not modified.
        target_device: ``torch.device`` or device string.

    Returns:
        A new ``ApiData``. ``args`` is a tuple, or ``None`` when the source
        had no args.
    """
    new_args = move2device_exec(buffer.args, target_device)
    if new_args is not None:
        # args defaults to None on ApiData; tuple(None) would raise.
        new_args = tuple(new_args)

    new_kwargs = move2device_exec(buffer.kwargs, target_device)

    to_cpu = target_device == torch.device('cpu') or target_device == "cpu"
    # Skip the transfer when the moved result would be thrown away anyway.
    new_results = move2device_exec(buffer.result, target_device) if to_cpu else buffer.result

    return ApiData(buffer.name, new_args, new_kwargs, new_results, buffer.step, buffer.rank)
@@ -0,0 +1,324 @@
1
+ import hashlib
2
+ import io
3
+ import struct
4
+ import time
5
+ import os
6
+ import signal
7
+ import sys
8
+ from queue import Queue
9
+ from threading import Thread
10
+ from typing import Union
11
+
12
+ from OpenSSL import SSL
13
+ from twisted.internet import ssl, reactor, protocol, endpoints
14
+ from twisted.protocols.basic import FileSender
15
+
16
+ from msprobe.pytorch.common.utils import logger
17
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.ssl_config import cipher_list
18
+
19
+
20
class TCPDataItem:
    """One payload waiting to be sent, together with its resend bookkeeping."""

    def __init__(self, data,
                 sequence_number: int,
                 rank: int = 0,
                 step: int = 0):
        """Store the payload and its identifiers; all counters start at zero.

        Args:
            data: Raw payload to transmit.
            sequence_number: Number assigned by the sender.
            rank: Rank the payload belongs to.
            step: Step the payload belongs to.
        """
        self.raw_data, self.sequence_number = data, sequence_number
        self.rank, self.step = rank, step
        # retry / pending / busy counters used by the resend logic
        self.retry_times = self.pending_time = self.busy_time = 0
32
+
33
+
34
class TCPClient:
    """Client that queues payloads, sends them and tracks their ACKs.

    Two daemon threads are started in ``__init__``: one drains
    ``send_queue`` through the protocol object, the other consumes ACKs from
    the protocol's ``ack_queue`` and updates ``resend_dict``, which holds
    every item that has been sent but not yet acknowledged. Its keys have
    the form ``"<sequence_number>_<rank>_<step>"``.
    """

    MAX_SENDING_QUEUE_SIZE = 20
    ACK_SUCCESS = b"OK___"
    ACK_ERROR = b"ERROR"
    ACK_BUSY = b"BUSY_"
    ACK_STOP = b"STOP_"
    ACK_STOP_CONFIRM = b"OVER_"
    ACK_KILL_PROCESS = b"KILL_"

    QUEUE_PENDING_TIME = 600  # seconds (10 minutes) to wait for a free send_queue slot before giving up
    RESEND_RETRY_TIMES = 2  # maximum number of retransmissions per item
    RESEND_TIMER_TIME = 5  # ACK receive timeout timer (not referenced inside this class)
    RESEND_PENDING_TIME = 60  # an item is dropped once its accumulated pending time reaches this value

    def __init__(self, host="localhost", port=8000, check_sum=False, tls_path=None):
        """Create the queues and protocol object and start the worker threads.

        Args:
            host: Server address to connect to.
            port: Server port.
            check_sum: Forwarded to ``ClientProtocol``.
            tls_path: Directory containing ``client.key`` / ``client.crt``;
                when falsy a plain TCP endpoint is used.
        """
        self.send_queue = Queue(self.MAX_SENDING_QUEUE_SIZE)
        self.resend_dict = dict()
        self.host = host
        self.port = port
        self.tls_path = tls_path
        self.factory = None
        self.sequence_number = 0
        self.signal_exit = False
        self.tcp_manager = ClientProtocol(ack_queue_size=100,
                                          chunk_size=655360,
                                          check_sum=check_sum)
        # daemon=True replaces the deprecated Thread.setDaemon() and matches
        # how start() creates the reactor thread.
        self.send_thread = Thread(target=self._sending_queue_data, daemon=True)
        self.send_thread.start()
        self.destroy_thread = Thread(target=self._destroy_queue_data, daemon=True)
        self.destroy_thread.start()

    @staticmethod
    def run_reactor():
        """Run the Twisted reactor without installing signal handlers."""
        reactor.run(installSignalHandlers=False)

    @staticmethod
    def _resend_key(sequence_number, rank, step):
        """Build the ``resend_dict`` key for an item."""
        return str(sequence_number) + "_" + str(rank) + "_" + str(step)

    def start(self):
        """Connect to the server and run the reactor in a daemon thread.

        With ``tls_path`` set an SSL endpoint is created from
        ``client.key`` / ``client.crt`` with the project cipher list and
        renegotiation disabled; otherwise a plain TCP endpoint is used.
        """
        def conn_callback(cur_protocol):
            if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host:
                logger.debug(f"Process: {os.getpid()} connects to server successfully.")
            else:
                logger.warning(f"Process: {os.getpid()} fails to connect to server. ")
                raise ConnectionError(f"Failed to connect to {self.host}.")

        def conn_err_callback(failure):
            self.signal_exit = True
            time.sleep(1)
            reactor.stop()
            logger.error(f"Failed to connected {self.host} {self.port}. Reason is {failure.getErrorMessage()}")
            # NOTE(review): SIGKILL to the own pid presumably ends the process
            # before the parent is signalled below - confirm the intended order.
            os.kill(os.getpid(), signal.SIGKILL)
            os.kill(os.getppid(), signal.SIGKILL)

        def cur_protocol():
            # The factory always hands out the single pre-built protocol object.
            return self.tcp_manager

        self.factory = MessageClientFactory()
        self.factory.protocol = cur_protocol
        if self.tls_path:
            client_key = os.path.join(self.tls_path, "client.key")
            client_crt = os.path.join(self.tls_path, "client.crt")
            client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt, SSL.TLSv1_2_METHOD)
            client_context_ = client_context_factory.getContext()
            client_context_.set_cipher_list(cipher_list)
            client_context_.set_options(SSL.OP_NO_RENEGOTIATION)
            endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory)
        else:
            endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port)
        d = endpoint.connect(self.factory)
        d.addCallback(conn_callback)
        d.addErrback(conn_err_callback)

        reactor_thread = Thread(target=self.run_reactor, daemon=True)
        reactor_thread.start()

    def send_after_queue_empty(self, data):
        """Queue ``data`` every two seconds until an exit has been signalled."""
        while not self._ready_to_exit():
            self.add_to_sending_queue(data)
            time.sleep(2)

    def check_client_alive(self):
        """Return True while the factory counts at least one connection."""
        return self.factory.num_connections > 0

    def stop(self):
        """Close the connection through the protocol's timeout handler."""
        self.tcp_manager.connection_timeout()

    def send_stop_signal(self):
        """Send the stop message and block until the kill flag is set.

        After the stop message loop ends, waits for either an exit signal or
        the connection count to drop to zero, then waits for the protocol's
        ``kill_process`` flag.
        """
        self.send_after_queue_empty(self.ACK_STOP)
        while not self._ready_to_exit():
            if not self.check_client_alive():
                break
            time.sleep(1)
        while not self.tcp_manager.kill_process:
            time.sleep(1)

    def add_to_sending_queue(self, data: Union[bytes, TCPDataItem], rank: int = 0, step: int = 0):
        """Put ``data`` on ``send_queue``; a no-op once an exit was signalled.

        Raw payloads are wrapped in a ``TCPDataItem`` with the next sequence
        number; an existing ``TCPDataItem`` (a resend) is queued unchanged.
        If no slot frees up within ``QUEUE_PENDING_TIME`` seconds the item
        is dropped and an error is logged.
        """
        if self._ready_to_exit():
            return

        send_data = data
        if not isinstance(data, TCPDataItem):
            send_data = TCPDataItem(data=data,
                                    sequence_number=self.sequence_number,
                                    rank=rank,
                                    step=step)
            self.sequence_number += 1
        try:
            self.send_queue.put(send_data, block=True, timeout=self.QUEUE_PENDING_TIME)
        except Exception as e:
            logger.error(f"send_queue put send_data timeout, rank: {send_data.rank}, step: {send_data.step},"
                         f"sequence_number: {send_data.sequence_number}, {str(e)}")

    def _send_data(self, data: TCPDataItem):
        """Hand one item to the protocol for transmission."""
        self.tcp_manager.send_wrapped_data(data.raw_data,
                                           sequence_number=data.sequence_number,
                                           rank=data.rank,
                                           step=data.step
                                           )

    def _sending_queue_data(self):
        """Sender thread: transmit queued items and register them for ACK tracking.

        At most ``MAX_SENDING_QUEUE_SIZE`` items may be unacknowledged at a
        time; while that limit is reached the thread waits.
        """
        while True:
            if not self.tcp_manager.is_connected:
                # Back off instead of spinning at full speed until connected.
                time.sleep(0.1)
                continue

            while self.send_queue.qsize() > 0:
                if self._ready_to_exit():
                    break
                if len(self.resend_dict) < self.MAX_SENDING_QUEUE_SIZE:
                    data_obj = self.send_queue.get()
                    self._send_data(data_obj)
                    resend_key = self._resend_key(data_obj.sequence_number, data_obj.rank, data_obj.step)
                    if resend_key not in self.resend_dict:
                        # Send data for the first time
                        self.resend_dict[resend_key] = data_obj
                else:
                    time.sleep(0.1)

            if self._ready_to_exit():
                logger.debug("Successfully close sending process.")
                break
            time.sleep(0.1)

    def _destroy_queue_data(self):
        """ACK thread: consume ACKs and release or resend the matching items.

        ACKs whose key is not in ``resend_dict`` are ignored.
        """
        while True:
            if self._ready_to_exit():
                break

            while len(self.resend_dict) > 0 and self.tcp_manager.ack_queue.qsize() > 0:
                ack_info, seq_number, rank, step = self.tcp_manager.ack_queue.get()
                obj_key = self._resend_key(seq_number, rank, step)
                current_item = self.resend_dict.get(obj_key)

                if current_item is None:
                    continue

                if ack_info == self.ACK_SUCCESS:
                    self.resend_dict.pop(obj_key)
                elif ack_info == self.ACK_BUSY:
                    logger.debug("RECV BUSY ACK")
                    if current_item.busy_time > 5:
                        self._resend_data(current_item)
                    else:
                        current_item.busy_time += 1
                elif ack_info == self.ACK_ERROR:
                    logger.debug("RECV ERROR ACK")
                    self._resend_data(current_item)
                elif ack_info == self.ACK_STOP_CONFIRM:
                    logger.debug("RECV STOP ACK")
                    self.factory.num_connections -= 1

                    break

            time.sleep(0.1)

    def _resend_data(self, data: TCPDataItem):
        """Requeue ``data`` or, after ``RESEND_RETRY_TIMES`` retries, drop it."""
        if data.retry_times < self.RESEND_RETRY_TIMES:
            data.retry_times += 1
            logger.debug(f"Resend data seq number: {data.sequence_number}")
            self.add_to_sending_queue(data)
        else:
            # resend_dict is keyed by "seq_rank_step"; popping with the bare
            # sequence number raised KeyError and killed the ACK thread.
            self.resend_dict.pop(self._resend_key(data.sequence_number, data.rank, data.step), None)
            logger.debug(f"SKIP send sequence number {data.sequence_number} after retry {data.retry_times} times!")

    def _pending_data(self, data: TCPDataItem):
        """Sleep in proportion to the payload size, or drop a long-pending item."""
        if data.pending_time >= self.RESEND_PENDING_TIME:
            self.resend_dict.pop(self._resend_key(data.sequence_number, data.rank, data.step), None)
            logger.debug(f"SKIP send sequence number {data.sequence_number} after pending {data.pending_time} times!")
            return

        # wait one second per 50 MiB of payload, at least one second
        pending_time = max(1, len(data.raw_data) // (2 ** 20 * 50))
        data.pending_time += pending_time
        time.sleep(pending_time)

    def _ready_to_exit(self):
        """Return True once this client or its protocol has signalled an exit."""
        return self.signal_exit or self.tcp_manager.signal_exit
230
+
231
+
232
class ClientProtocol(protocol.Protocol):
    """Twisted protocol that sends framed payloads and parses fixed-size ACKs.

    Outgoing frame: four 8-byte big-endian integers (payload length,
    sequence number, rank, step), then the MD5 hex digest when ``check_sum``
    is on, then the payload.
    Incoming ACK frame: a 5-byte code followed by three 8-byte big-endian
    unsigned integers (sequence number, rank, step).
    """

    TIMEOUT = 60 * 10
    ACK_FRAME_SIZE = 29  # 5 + 8 * 3

    def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False):
        """Initialise buffers, the ACK queue and the file sender.

        Args:
            ack_queue_size: Capacity of ``ack_queue``.
            chunk_size: Chunk size assigned to the ``FileSender``.
            check_sum: Whether an MD5 digest is added to outgoing frames.
        """
        self.buffer = io.BytesIO()
        self.is_connected = False
        self.check_sum = check_sum
        self.tell = 0
        self.ack_queue = Queue(maxsize=ack_queue_size)
        self.file_sender = FileSender()
        self.file_sender.CHUNK_SIZE = chunk_size
        self.signal_exit = False
        self.defer = None
        self.kill_process = False
        self.timeout_call = None  # scheduled in connectionMade

    def dataReceived(self, data):
        """Append ``data`` to the buffer and enqueue every complete ACK frame.

        A trailing partial frame stays in ``self.buffer`` for the next call.
        A ``KILL_`` code sets ``kill_process``; an ``OVER_`` code decrements
        the factory's connection count. When ``ack_queue`` is full the
        method waits in 0.1 s steps until the parsed frame can be enqueued;
        note that this wait happens on the thread that delivers the data.
        """
        if self.timeout_call is not None and self.timeout_call.active():
            self.timeout_call.reset(self.TIMEOUT)

        self.buffer.seek(0, 2)
        self.buffer.write(data)
        pending = self.buffer.getvalue()[self.tell:]
        offset = 0
        # Only parse while a whole frame is left; the former check looked at
        # the total buffer size and could read past the end.
        while len(pending) - offset >= self.ACK_FRAME_SIZE:
            ack, seq_number, rank, step = struct.unpack_from('!5sQQQ', pending, offset)
            offset += self.ACK_FRAME_SIZE
            if ack == b"KILL_":
                self.kill_process = True
                logger.debug(f"接收到KILL信号, PID {os.getpid()}")
            if ack == b"OVER_":
                self.factory.num_connections -= 1
            # Retry the same frame rather than dropping it when the queue is full.
            while self.ack_queue.full():
                time.sleep(0.1)
            self.ack_queue.put((ack, seq_number, rank, step))
        self.buffer = io.BytesIO(pending[offset:])
        self.tell = 0

    def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0):
        """Frame ``data`` and start sending it once the previous transfer is done.

        Polls every 10 ms until no earlier transfer is pending, then stores
        the new transfer's deferred in ``self.defer``.
        """
        md5_hash = hashlib.md5(data).hexdigest() if self.check_sum else ""
        frame = b"".join((
            len(data).to_bytes(8, byteorder='big'),
            sequence_number.to_bytes(8, byteorder='big'),
            rank.to_bytes(8, byteorder='big'),
            step.to_bytes(8, byteorder='big'),
            md5_hash.encode(),
            data,
        ))
        while True:
            if self.defer is None or self.defer.called:
                self.defer = self.send_large_data(frame)
                break
            time.sleep(0.01)

    def send_large_data(self, data):
        """Stream ``data`` to the transport in chunks; returns the transfer's deferred."""
        d = self.file_sender.beginFileTransfer(io.BytesIO(data), self.transport)
        return d

    def connection_timeout(self):
        """Close the connection unless the factory already counts none."""
        if self.factory.num_connections <= 0:
            return

        self.factory.num_connections -= 1
        logger.debug(f"超时退出{self.transport.addr}, PID {os.getpid()}")
        self.transport.loseConnection()

    def connectionMade(self):
        """Arm the inactivity timeout and mark the protocol as connected."""
        self.timeout_call = reactor.callLater(self.TIMEOUT, self.connection_timeout)
        self.is_connected = True
        self.factory.num_connections += 1
        logger.info("successfully connect server")

    def connectionLost(self, reason):
        """Signal exit and decrement the factory's connection count."""
        self.signal_exit = True
        self.factory.num_connections -= 1
        logger.info(f"Lost connection with server, reason is : {reason}")
312
+
313
+
314
class MessageClientFactory(protocol.ClientFactory):
    """Client factory that counts connections and stops the reactor on failure or loss."""

    def __init__(self):
        # Incremented / decremented by the protocol and client objects.
        self.num_connections = 0

    def clientConnectionFailed(self, connector, reason):
        """Log the failure reason, then stop the reactor."""
        detail = reason.getErrorMessage()
        logger.info(f"Fail to connection with server: {detail}")
        reactor.stop()

    def clientConnectionLost(self, connector, reason):
        """Log the reason the connection was lost, then stop the reactor."""
        detail = reason.getErrorMessage()
        logger.info(f"Client lost connection with server: {detail}")
        reactor.stop()