mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
@@ -1,130 +0,0 @@
1
- import unittest
2
- from msprobe.pytorch.hook_module.api_registry import ApiRegistry, torch_version_above_2, is_gpu
3
-
4
-
5
- class TestApiRegistry(unittest.TestCase):
6
-
7
- def test_store_ori_attr(self):
8
- class A():
9
- a1 = 1
10
- class B():
11
- a = A()
12
- b1 = 1
13
- b2 = 2
14
-
15
- api_list = ["a.a1", "b1", "b2"]
16
- expect_output = {"a.a1":1, "b1":1, "b2":2}
17
- actual_output = dict()
18
- ApiRegistry.store_ori_attr(B, api_list, actual_output)
19
- self.assertEqual(actual_output, expect_output)
20
-
21
-
22
- def test_set_api_attr(self):
23
- class A():
24
- a1 = 1
25
- class B():
26
- a = A().__class__
27
- b1 = 1
28
-
29
- attr_dict = {"a.a2":2, "b2":2, "b3":3}
30
- ApiRegistry.set_api_attr(B, attr_dict)
31
-
32
- for k, v in attr_dict.items():
33
- if '.' in k:
34
- sub_module_name, sub_op = k.rsplit('.', 1)
35
- sub_module = getattr(B, sub_module_name, None)
36
-
37
- self.assertEqual(getattr(sub_module, sub_op), v)
38
- else:
39
- self.assertEqual(getattr(B, k), v)
40
-
41
- def test_api_modularity(self):
42
-
43
- import torch
44
- import torch.distributed as dist
45
- #import torch_npu #门禁没有安装torch_npu
46
- from msprobe.pytorch.hook_module.api_registry import torch_without_guard_version, npu_distributed_api, is_gpu, torch_version_above_2
47
-
48
-
49
-
50
- reg = ApiRegistry()
51
- attr_dict = {"b2":2, "b3":3}
52
- reg.tensor_hook_attr = attr_dict
53
- reg.torch_hook_attr = attr_dict
54
- reg.functional_hook_attr = attr_dict
55
- reg.distributed_hook_attr = attr_dict
56
- reg.npu_distributed_hook_attr = attr_dict
57
- reg.aten_hook_attr = attr_dict
58
- reg.vf_hook_attr = attr_dict
59
- reg.torch_npu_hook_attr = attr_dict
60
-
61
- reg.api_modularity()
62
- self.assertEqual(torch.Tensor.b2, 2)
63
-
64
- self.assertEqual(torch.b2, 2)
65
- self.assertEqual(torch.nn.functional.b2, 2)
66
- self.assertEqual(dist.b2, 2)
67
- self.assertEqual(dist.distributed_c10d.b2, 2)
68
- #if not is_gpu and not torch_without_guard_version:
69
- #self.assertEqual(torch_npu.distributed.b2, 2)
70
- #self.assertEqual(torch_npu.distributed.distributed_c10d.b2, 2)
71
- if torch_version_above_2:
72
- self.assertEqual(torch.ops.aten.b2, 2)
73
- self.assertEqual(torch._VF.b2, 2)
74
- #if not is_gpu:
75
- #self.assertEqual(torch_npu.b2, 2)
76
-
77
-
78
- def test_api_originality(self):
79
- import torch
80
- import torch.distributed as dist
81
- #import torch_npu #门禁没有安装torch_npu
82
- from msprobe.pytorch.hook_module.api_registry import torch_without_guard_version, npu_distributed_api, is_gpu, torch_version_above_2
83
-
84
-
85
-
86
- reg = ApiRegistry()
87
- attr_dict = {"b2":2, "b3":3}
88
- reg.tensor_hook_attr = attr_dict
89
- reg.torch_hook_attr = attr_dict
90
- reg.functional_hook_attr = attr_dict
91
- reg.distributed_hook_attr = attr_dict
92
- reg.npu_distributed_hook_attr = attr_dict
93
- reg.aten_hook_attr = attr_dict
94
- reg.vf_hook_attr = attr_dict
95
- reg.torch_npu_hook_attr = attr_dict
96
-
97
- reg.api_originality()
98
- self.assertEqual(torch.Tensor.b2, 2)
99
-
100
- self.assertEqual(torch.b2, 2)
101
- self.assertEqual(torch.nn.functional.b2, 2)
102
- self.assertEqual(dist.b2, 2)
103
- self.assertEqual(dist.distributed_c10d.b2, 2)
104
- #if not is_gpu and not torch_without_guard_version:
105
- #self.assertEqual(torch_npu.distributed.b2, 2)
106
- #self.assertEqual(torch_npu.distributed.distributed_c10d.b2, 2)
107
- if torch_version_above_2:
108
- self.assertEqual(torch.ops.aten.b2, 2)
109
- self.assertEqual(torch._VF.b2, 2)
110
- #if not is_gpu:
111
- #self.assertEqual(torch_npu.b2, 2)
112
-
113
- def test_initialize_hook(self):
114
- def hook_test():
115
- pass
116
-
117
- reg = ApiRegistry()
118
- reg.initialize_hook(hook_test)
119
- empty_list = []
120
- self.assertFalse(empty_list==reg.tensor_hook_attr)
121
- self.assertFalse(empty_list==reg.torch_hook_attr)
122
- self.assertFalse(empty_list==reg.functional_hook_attr)
123
- self.assertFalse(empty_list==reg.distributed_hook_attr)
124
- self.assertFalse(empty_list==reg.npu_distributed_hook_attr)
125
- if torch_version_above_2:
126
- #print(True)
127
- self.assertFalse(empty_list==reg.aten_hook_attr)
128
- if not is_gpu:
129
- #print(True)
130
- self.assertFalse(empty_list==reg.torch_npu_hook_attr)
@@ -1,42 +0,0 @@
1
- import unittest
2
- from unittest.mock import patch, Mock
3
-
4
- from msprobe.pytorch.hook_module.hook_module import HOOKModule
5
-
6
- class TestHookModule(unittest.TestCase):
7
- def test_call_1(self):
8
- def forward_pre_hook():
9
- return "result_input", "result_kwargs"
10
- def forward_hook():
11
- return 2
12
- def backward_hook():
13
- pass
14
-
15
- def hook(prefix):
16
- return forward_pre_hook, forward_hook, backward_hook
17
- HOOKModule.prefix_op_name_ = "123"
18
- test = HOOKModule(hook)
19
- test._call_func = Mock(return_value=1)
20
- result = test()
21
- self.assertEqual(result, 1)
22
-
23
- def test_call_2(self):
24
- def forward_pre_hook(nope, input, kwargs):
25
- return input, kwargs
26
- def forward_hook(nope, input, kwargs, result):
27
- return input
28
- def backward_hook():
29
- pass
30
-
31
- def hook(prefix):
32
- return forward_pre_hook, forward_hook, backward_hook
33
- HOOKModule.prefix_op_name_ = "123"
34
- input = 2
35
- test = HOOKModule(hook)
36
-
37
- def temp_forward(*input, **kwargs):
38
- return input
39
-
40
- test.forward = Mock(return_value=1)
41
- result = test(input)
42
- self.assertEqual(result, (input, ))
@@ -1,65 +0,0 @@
1
- import unittest
2
- import torch
3
- from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate, AtenOPPacketTemplate
4
-
5
-
6
- def hook(name):
7
- def forward_pre_hook(nope, input, kwargs):
8
- return input, kwargs
9
- def forward_hook(nope, input, kwargs, result):
10
- return 2
11
- def backward_hook():
12
- pass
13
-
14
- return forward_pre_hook, forward_hook, backward_hook
15
-
16
-
17
-
18
- class TestWrapAten(unittest.TestCase):
19
- def setUp(self):
20
- self.aten_op = AtenOPPacketTemplate(torch.ops.aten.convolution, hook)
21
-
22
- def test_atenop_attribute(self):
23
- if torch.__version__.split("+")[0] <= '2.0':
24
- return
25
- self.setUp()
26
- self.assertEqual(self.aten_op.default.op, torch.ops.aten.convolution.default)
27
- self.assertEqual(self.aten_op.out.op, torch.ops.aten.convolution.out)
28
-
29
- def test_atenop_forward(self):
30
- if torch.__version__.split("+")[0] <= '2.0':
31
- return
32
- self.setUp()
33
- image = torch.randn(4, 3, 24, 24)
34
- kernel = torch.randn(10, 3, 3, 3)
35
- functional_out = torch.nn.functional.conv2d(image, kernel, stride=[1, 1],
36
- padding=[1, 1], dilation=[1, 1], groups=1, bias=None)
37
- aten_out = self.aten_op(image, kernel, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)
38
- self.assertTrue(aten_out == 2)
39
-
40
- def test_atenop_overload_forward(self):
41
- if torch.__version__.split("+")[0] <= '2.0':
42
- return
43
- self.setUp()
44
- image = torch.randn(4, 3, 24, 24)
45
- kernel = torch.randn(10, 3, 3, 3)
46
- functional_out = torch.nn.functional.conv2d(image, kernel, stride=[1, 1],
47
- padding=[1, 1], dilation=[1, 1], groups=1, bias=None)
48
- aten_out = self.aten_op(image, kernel, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)
49
- self.assertTrue(aten_out == 2)
50
-
51
- def test_atenop_nonattr(self):
52
- if torch.__version__.split("+")[0] <= '2.0':
53
- return
54
- self.setUp()
55
- self.assertRaises(AttributeError, getattr, self.aten_op, "foo")
56
-
57
- def test_atenop_overloads(self):
58
- if torch.__version__.split("+")[0] <= '2.0':
59
- return
60
- self.setUp()
61
- self.assertEqual(self.aten_op.overloads(), self.aten_op.opPacket.overloads())
62
-
63
-
64
-
65
-
@@ -1,35 +0,0 @@
1
- import unittest
2
- import torch.distributed as dist
3
- from msprobe.pytorch.hook_module.wrap_distributed import *
4
-
5
- class TestWrapDistributed(unittest.TestCase):
6
- def hook(name, prefix):
7
- def forward_pre_hook(nope, input, kwargs):
8
- return input, kwargs
9
- def forward_hook(nope, input, kwargs, result):
10
- return 2
11
- def backward_hook():
12
- pass
13
- return forward_pre_hook, forward_hook, backward_hook
14
-
15
- def test_get_distributed_ops(self):
16
- ops = get_distributed_ops()
17
- self.assertIsInstance(ops, set)
18
-
19
- def test_DistributedOPTemplate(self):
20
- self.setUp()
21
- op_name = 'all_reduce'
22
- if op_name in get_distributed_ops():
23
- op = DistributedOPTemplate(op_name, self.hook)
24
- self.assertEqual(op.op_name_, op_name)
25
-
26
- def test_wrap_distributed_op(self):
27
- op_name = 'all_reduce'
28
- if op_name in get_distributed_ops():
29
- wrapped_op = wrap_distributed_op(op_name, self.hook)
30
- self.assertTrue(callable(wrapped_op))
31
-
32
- def test_wrap_distributed_ops_and_bind(self):
33
- wrap_distributed_ops_and_bind(self.hook)
34
- for op_name in get_distributed_ops():
35
- self.assertTrue(hasattr(HOOKDistributedOP, "wrap_" + str(op_name)))
@@ -1,20 +0,0 @@
1
- import unittest
2
- import torch
3
- from msprobe.pytorch.hook_module import wrap_functional as wf
4
-
5
- class TestWrapFunctional(unittest.TestCase):
6
-
7
- def test_remove_dropout(self):
8
- input_tensor = torch.randn(20, 16)
9
- wf.remove_dropout()
10
- output_tensor = torch.nn.functional.dropout(input_tensor)
11
- self.assertTrue(torch.equal(input_tensor, output_tensor))
12
-
13
- def test_get_functional_ops(self):
14
- expected_ops = {'relu', 'sigmoid', 'softmax'}
15
- actual_ops = wf.get_functional_ops()
16
- self.assertTrue(expected_ops.issubset(actual_ops))
17
-
18
- def test_wrap_functional_ops_and_bind(self):
19
- wf.wrap_functional_ops_and_bind(None)
20
- self.assertTrue(hasattr(wf.HOOKFunctionalOP, 'wrap_relu'))
@@ -1,35 +0,0 @@
1
- import unittest
2
- import torch
3
- import yaml
4
- from msprobe.pytorch.hook_module.wrap_tensor import get_tensor_ops, HOOKTensor, TensorOPTemplate, wrap_tensor_op, wrap_tensor_ops_and_bind
5
-
6
- class TestWrapTensor(unittest.TestCase):
7
-
8
- def hook(name, prefix):
9
- def forward_pre_hook(nope, input, kwargs):
10
- return input, kwargs
11
- def forward_hook(nope, input, kwargs, result):
12
- return 2
13
- def backward_hook():
14
- pass
15
- return forward_pre_hook, forward_hook, backward_hook
16
-
17
- def test_get_tensor_ops(self):
18
- result = get_tensor_ops()
19
- self.assertIsInstance(result, set)
20
-
21
- def test_HOOKTensor(self):
22
- hook_tensor = HOOKTensor()
23
- self.assertIsInstance(hook_tensor, HOOKTensor)
24
-
25
- def test_TensorOPTemplate(self):
26
- tensor_op_template = TensorOPTemplate('add', self.hook)
27
- self.assertTrue(tensor_op_template.op_name_, 'add')
28
-
29
- def test_wrap_tensor_op(self):
30
- wrapped_op = wrap_tensor_op('add', self.hook)
31
- self.assertTrue(callable(wrapped_op))
32
-
33
- def test_wrap_tensor_ops_and_bind(self):
34
- wrap_tensor_ops_and_bind(self.hook)
35
- self.assertTrue(hasattr(HOOKTensor, 'wrap_add'))
@@ -1,43 +0,0 @@
1
- import unittest
2
- import torch
3
- import yaml
4
- from msprobe.pytorch.hook_module.wrap_torch import *
5
-
6
- class TestWrapTorch(unittest.TestCase):
7
-
8
- def hook(name, prefix):
9
- def forward_pre_hook(nope, input, kwargs):
10
- return input, kwargs
11
- def forward_hook(nope, input, kwargs, result):
12
- return 2
13
- def backward_hook():
14
- pass
15
- return forward_pre_hook, forward_hook, backward_hook
16
-
17
- def setUp(self):
18
-
19
- self.op_name = 'add'
20
- self.torch_op = wrap_torch_op(self.op_name, self.hook)
21
-
22
- def test_get_torch_ops(self):
23
- self.setUp()
24
- ops = get_torch_ops()
25
- self.assertIsInstance(ops, set)
26
- self.assertIn(self.op_name, ops)
27
-
28
- def test_TorchOPTemplate(self):
29
- self.setUp()
30
- template = TorchOPTemplate(self.op_name, self.hook)
31
- self.assertEqual(template.op_name_, self.op_name)
32
- self.assertEqual(template.prefix_op_name_, "Torch." + str(self.op_name) + ".")
33
-
34
- def test_forward(self):
35
- self.setUp()
36
- template = TorchOPTemplate(self.op_name, self.hook)
37
- result = template.forward(torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]))
38
- torch.testing.assert_close(result, torch.tensor([5, 7, 9]))
39
-
40
- def test_wrap_torch_ops_and_bind(self):
41
- self.setUp()
42
- wrap_torch_ops_and_bind(self.hook)
43
- self.assertTrue(hasattr(HOOKTorchOP, "wrap_" + self.op_name))
@@ -1,11 +0,0 @@
1
- import unittest
2
- import torch
3
- from msprobe.pytorch.hook_module import wrap_vf
4
-
5
- class TestWrapVF(unittest.TestCase):
6
- def setUp(self):
7
- self.hook = lambda x: x
8
-
9
- def test_get_vf_ops(self):
10
- ops = wrap_vf.get_vf_ops()
11
- self.assertIsInstance(ops, list)
@@ -1,69 +0,0 @@
1
- from unittest import TestCase
2
- from unittest.mock import patch, mock_open
3
-
4
- from msprobe.core.common.const import Const
5
- from msprobe.pytorch.pt_config import parse_json_config, parse_task_config
6
-
7
-
8
- class TestPtConfig(TestCase):
9
- def test_parse_json_config(self):
10
- mock_json_data = {
11
- "task": "statistics",
12
- "dump_path": "./dump/",
13
- "rank": [],
14
- "step": [],
15
- "level": "L1",
16
- "seed": 1234,
17
- "statistics": {
18
- "scope": [],
19
- "list": [],
20
- "data_mode": ["all"],
21
- },
22
- "tensor": {
23
- "file_format": "npy"
24
- }
25
- }
26
- with patch("msprobe.pytorch.pt_config.os.path.join", return_value="/path/config.json"), \
27
- patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \
28
- patch("msprobe.pytorch.pt_config.json.load", return_value=mock_json_data):
29
- common_config, task_config = parse_json_config(None, None)
30
- self.assertEqual(common_config.task, Const.STATISTICS)
31
- self.assertEqual(task_config.data_mode, ["all"])
32
-
33
- with patch("msprobe.pytorch.pt_config.os.path.join", return_value="/path/config.json"), \
34
- patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \
35
- patch("msprobe.pytorch.pt_config.json.load", return_value=mock_json_data):
36
- common_config, task_config = parse_json_config(None, Const.TENSOR)
37
- self.assertEqual(common_config.task, Const.STATISTICS)
38
- self.assertEqual(task_config.file_format, "npy")
39
-
40
- def test_parse_task_config(self):
41
- overflow_check_config = {
42
- "overflow_check": {
43
- "overflow_nums": 1,
44
- "check_mode": "all"
45
- }
46
- }
47
- result = parse_task_config(Const.OVERFLOW_CHECK, overflow_check_config)
48
- self.assertEqual(result.overflow_num, 1)
49
- self.assertEqual(result.check_mode, "all")
50
-
51
- free_benchmark_config = {
52
- "free_benchmark": {
53
- "scope": [],
54
- "list": ["conv2d"],
55
- "fuzz_device": "npu",
56
- "pert_mode": "improve_precision",
57
- "handler_type": "check",
58
- "fuzz_level": "L1",
59
- "fuzz_stage": "forward",
60
- "if_preheat": False,
61
- "preheat_step": 15,
62
- "max_sample": 20
63
- }
64
- }
65
- result = parse_task_config(Const.FREE_BENCHMARK, free_benchmark_config)
66
- self.assertEqual(result.pert_mode, "improve_precision")
67
- self.assertEqual(result.handler_type, "check")
68
- self.assertEqual(result.preheat_step, 15)
69
- self.assertEqual(result.max_sample, 20)
@@ -1,59 +0,0 @@
1
- import unittest
2
- from unittest.mock import patch, mock_open
3
-
4
- import torch.nn as nn
5
- from msprobe.core.common.utils import Const
6
- from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
7
- from msprobe.pytorch.pt_config import parse_json_config
8
- from msprobe.pytorch.service import Service
9
-
10
-
11
- class TestService(unittest.TestCase):
12
- def setUp(self):
13
- mock_json_data = {
14
- "dump_path": "./dump/",
15
- }
16
- with patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \
17
- patch("msprobe.pytorch.pt_config.json.load", return_value=mock_json_data):
18
- common_config, task_config = parse_json_config("./config.json", Const.STATISTICS)
19
- self.config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./ut_dump", "L1")
20
- self.service = Service(self.config)
21
-
22
- def test_start(self):
23
- with patch("msprobe.pytorch.service.get_rank_if_initialized", return_value=0), \
24
- patch("msprobe.pytorch.service.Service.create_dirs", return_value=None):
25
- self.service.start(None)
26
- self.assertEqual(self.service.current_rank, 0)
27
-
28
- def test_stop_and_step(self):
29
- with patch("msprobe.core.data_dump.data_collector.DataCollector.write_json", return_value=None):
30
- self.service.stop()
31
- self.assertFalse(self.service.switch)
32
-
33
- self.service.step()
34
- self.assertEqual(self.service.current_iter, 1)
35
-
36
- def test_register_hook_new(self):
37
- class TestModule(nn.Module):
38
- def __init__(self) -> None:
39
- super().__init__()
40
- self.linear = nn.Linear(in_features=8, out_features=4)
41
-
42
- def forward(self, x):
43
- x = self.linear(x)
44
- return x
45
-
46
- self.service.model = TestModule()
47
- self.config.level = "L0"
48
- with patch("msprobe.pytorch.service.logger.info_on_rank_0") as mock_logger, \
49
- patch("msprobe.pytorch.service.remove_dropout", return_value=None):
50
- self.service.register_hook_new()
51
- self.assertEqual(mock_logger.call_count, 2)
52
-
53
- def test_create_dirs(self):
54
- with patch("msprobe.pytorch.service.Path.mkdir", return_value=None), \
55
- patch("msprobe.core.common.file_check.FileChecker.common_check", return_value=None), \
56
- patch("msprobe.core.data_dump.data_collector.DataCollector.update_dump_paths",
57
- return_value=None):
58
- self.service.create_dirs()
59
- self.assertEqual(self.service.dump_iter_dir, "./ut_dump/step0")
@@ -1,3 +0,0 @@
1
- Line: NA
2
- Suspect Nodes: NA
3
- Expert Advice: All data in comparison result meets the accuracy requirements.
@@ -1,9 +0,0 @@
1
- NPU Name,Bench Name,NPU Tensor Dtype,Bench Tensor Dtype,NPU Tensor Shape,Bench Tensor Shape,Cosine,MaxAbsErr,NPU max,NPU min,NPU mean,Bench max,Bench min,Bench mean,Accuracy Reached or Not,Err_message
2
- Functional_linear_0_forward_input.0,Functional_linear_0_forward_input.0,torch.float32,torch.float32,"[3, 2]","[3, 2]",1.0,0.000000,1.948258399963379,-1.0052297115325928,-0.2003595232963562,1.948258399963379,-1.0052297115325928,-0.2003595232963562,Yes,
3
- Functional_linear_0_forward_input.1,Functional_linear_0_forward_input.1,torch.float32,torch.float32,"[3, 2]","[3, 2]",1.0,0.000000,0.28375449776649475,-0.6661239266395569,-0.2789986729621887,0.28375449776649475,-0.6661239266395569,-0.2789986729621887,Yes,
4
- Functional_linear_0_forward_input.2,Functional_linear_0_forward_input.2,torch.float32,torch.float32,[3],[3],1.0,0.000000,0.2457989901304245,-0.6338542103767395,-0.14437106251716614,0.2457989901304245,-0.6338542103767395,-0.14437106251716614,Yes,
5
- Functional_linear_0_forward_output,Functional_linear_0_forward_output,torch.float32,torch.float32,"[3, 3]","[3, 3]",1.0,0.000000,0.8278868794441223,-0.8729169964790344,0.16790540516376495,0.8278868794441223,-0.8729169964790344,0.16790540516376495,Yes,
6
- Torch_relu_0_forward_input.0,Torch_relu_0_forward_input.0,torch.float32,torch.float32,"[3, 3]","[3, 3]",1.0,0.000000,0.8278868794441223,-0.8729169964790344,0.16790540516376495,0.8278868794441223,-0.8729169964790344,0.16790540516376495,Yes,
7
- Torch_relu_0_forward_output,Torch_relu_0_forward_output,torch.float32,torch.float32,"[3, 3]","[3, 3]",1.0,0.000000,0.8278868794441223,0.0,0.31367552280426025,0.8278868794441223,0.0,0.31367552280426025,Yes,
8
- Functional_relu_0_forward_input.0,Functional_relu_0_forward_input.0,torch.float32,torch.float32,"[3, 3]","[3, 3]",1.0,0.000000,0.8278868794441223,-0.8729169964790344,0.16790540516376495,0.8278868794441223,-0.8729169964790344,0.16790540516376495,Yes,
9
- Functional_relu_0_forward_output,Functional_relu_0_forward_output,torch.float32,torch.float32,"[3, 3]","[3, 3]",1.0,0.000000,0.8278868794441223,0.0,0.31367552280426025,0.8278868794441223,0.0,0.31367552280426025,Yes,
@@ -1,9 +0,0 @@
1
- NPU Name,Bench Name,NPU Tensor Dtype,Bench Tensor Dtype,NPU Tensor Shape,Bench Tensor Shape,Cosine,MaxAbsErr,NPU max,NPU min,NPU mean,Bench max,Bench min,Bench mean,Accuracy Reached or Not,Err_message
2
- ,Functional_linear_0_forward_input.0,torch.float32,torch.float32,"[3, 2]","[3, 2]",1,0,1.9482584,-1.005229712,-0.200359523,1.9482584,-1.005229712,-0.200359523,,
3
- ,Functional_linear_0_forward_input.1,torch.float32,torch.float32,"[3, 2]","[3, 2]",1,0,0.283754498,-0.666123927,-0.278998673,0.283754498,-0.666123927,-0.278998673,,
4
- ,Functional_linear_0_forward_input.2,torch.float32,torch.float32,[3],[3],1,0,0.24579899,-0.63385421,-0.144371063,0.24579899,-0.63385421,-0.144371063,,
5
- ,Functional_linear_0_forward_output,torch.float32,torch.float32,"[3, 3]","[3, 3]",1,0,0.827886879,-0.872916996,0.167905405,0.827886879,-0.872916996,0.167905405,,
6
- ,Torch_relu_0_forward_input.0,torch.float32,torch.float32,"[3, 3]","[3, 3]",1,0,0.827886879,-0.872916996,0.167905405,0.827886879,-0.872916996,0.167905405,,
7
- ,Torch_relu_0_forward_output,torch.float32,torch.float32,"[3, 3]","[3, 3]",1,0,0.827886879,0,0.313675523,0.827886879,0,0.313675523,,
8
- ,Functional_relu_0_forward_input.0,torch.float32,torch.float32,"[3, 3]","[3, 3]",1,0,0.827886879,-0.872916996,0.167905405,0.827886879,-0.872916996,0.167905405,,
9
- ,Functional_relu_0_forward_output,torch.float32,torch.float32,"[3, 3]","[3, 3]",1,0,0.827886879,0,0.313675523,0.827886879,0,0.313675523,,
@@ -1,3 +0,0 @@
1
- white_list: []
2
- error_data_path: './'
3
- precision: 14
@@ -1,8 +0,0 @@
1
- ["Functional_linear_0_forward_input.0", 1, [], "torch.float32", [3, 2], [1.948258399963379, -1.0052297115325928, -0.2003595232963562]]
2
- ["Functional_linear_0_forward_input.1", 1, [], "torch.float32", [3, 2], [0.28375449776649475, -0.6661239266395569, -0.2789986729621887]]
3
- ["Functional_linear_0_forward_input.2", 1, [], "torch.float32", [3], [0.2457989901304245, -0.6338542103767395, -0.14437106251716614]]
4
- ["Functional_linear_0_forward_output", 1, [], "torch.float32", [3, 3], [0.8278868794441223, -0.8729169964790344, 0.16790540516376495]]
5
- ["Torch_relu_0_forward_input.0", 1, [], "torch.float32", [3, 3], [0.8278868794441223, -0.8729169964790344, 0.16790540516376495]]
6
- ["Torch_relu_0_forward_output", 1, [], "torch.float32", [3, 3], [0.8278868794441223, 0.0, 0.31367552280426025]]
7
- ["Functional_relu_0_forward_input.0", 1, [], "torch.float32", [3, 3], [0.8278868794441223, -0.8729169964790344, 0.16790540516376495]]
8
- ["Functional_relu_0_forward_output", 1, [], "torch.float32", [3, 3], [0.8278868794441223, 0.0, 0.31367552280426025]]
msprobe/test/run_test.sh DELETED
@@ -1,30 +0,0 @@
1
- #!/bin/bash
2
- CUR_DIR=$(dirname $(readlink -f $0))
3
- TOP_DIR=${CUR_DIR}/..
4
- TEST_DIR=${TOP_DIR}/"test"
5
- SRC_DIR=${TOP_DIR}/../
6
-
7
- install_pytest() {
8
- if ! pip show pytest &> /dev/null; then
9
- echo "pytest not found, trying to install..."
10
- pip install pytest
11
- fi
12
-
13
- if ! pip show pytest-cov &> /dev/null; then
14
- echo "pytest-cov not found, trying to install..."
15
- pip install pytest-cov
16
- fi
17
- }
18
-
19
- run_ut() {
20
- install_pytest
21
-
22
- export PYTHONPATH=${SRC_DIR}:${PYTHONPATH}
23
- python3 run_ut.py
24
- }
25
-
26
- main() {
27
- cd ${TEST_DIR} && run_ut
28
- }
29
-
30
- main $@
msprobe/test/run_ut.py DELETED
@@ -1,58 +0,0 @@
1
- import os
2
- import shutil
3
- import subprocess
4
- import sys
5
-
6
- from msprobe.core.common.log import logger
7
-
8
-
9
- def run_ut():
10
- cur_dir = os.path.realpath(os.path.dirname(__file__))
11
- ut_path = cur_dir
12
- cov_dir = os.path.dirname(cur_dir)
13
- report_dir = os.path.join(cur_dir, "report")
14
- final_xml_path = os.path.join(report_dir, "final.xml")
15
- cov_report_path = os.path.join(report_dir, "coverage.xml")
16
-
17
- if os.path.exists(report_dir):
18
- shutil.rmtree(report_dir)
19
- os.makedirs(report_dir)
20
-
21
- pytest_cmd = [
22
- "python3", "-m", "pytest",
23
- ut_path,
24
- f"--junitxml={final_xml_path}",
25
- f"--cov={cov_dir}",
26
- "--cov-branch",
27
- f"--cov-report=xml:{cov_report_path}",
28
- ]
29
-
30
- try:
31
- with subprocess.Popen(
32
- pytest_cmd,
33
- shell=False,
34
- stdout=subprocess.PIPE,
35
- stderr=subprocess.STDOUT,
36
- text=True,
37
- ) as proc:
38
- for line in proc.stdout:
39
- logger.info(line.strip())
40
-
41
- proc.wait()
42
-
43
- if proc.returncode == 0:
44
- logger.info("Unit tests executed successfully.")
45
- return True
46
- else:
47
- logger.error("Unit tests execution failed.")
48
- return False
49
- except Exception as e:
50
- logger.error(f"An error occurred during test execution: {e}")
51
- return False
52
-
53
-
54
- if __name__ == "__main__":
55
- if run_ut():
56
- sys.exit(0)
57
- else:
58
- sys.exit(1)