mindstudio-probe 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228) hide show
  1. mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
  2. mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
  3. mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
  4. mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
  5. mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
  6. mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
  7. msprobe/README.md +182 -0
  8. msprobe/__init__.py +0 -0
  9. msprobe/config/README.md +397 -0
  10. msprobe/config/config.json +28 -0
  11. msprobe/config/img/free_benchmark.png +0 -0
  12. msprobe/core/common/const.py +241 -0
  13. msprobe/core/common/exceptions.py +88 -0
  14. msprobe/core/common/file_check.py +265 -0
  15. msprobe/core/common/log.py +55 -0
  16. msprobe/core/common/utils.py +516 -0
  17. msprobe/core/common_config.py +58 -0
  18. msprobe/core/data_dump/data_collector.py +140 -0
  19. msprobe/core/data_dump/data_processor/base.py +245 -0
  20. msprobe/core/data_dump/data_processor/factory.py +61 -0
  21. msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
  22. msprobe/core/data_dump/json_writer.py +116 -0
  23. msprobe/core/data_dump/scope.py +178 -0
  24. msprobe/mindspore/__init__.py +1 -0
  25. msprobe/mindspore/debugger/__init__.py +0 -0
  26. msprobe/mindspore/debugger/debugger_config.py +51 -0
  27. msprobe/mindspore/debugger/precision_debugger.py +32 -0
  28. msprobe/mindspore/doc/dump.md +65 -0
  29. msprobe/mindspore/dump/__init__.py +0 -0
  30. msprobe/mindspore/dump/api_kbk_dump.py +55 -0
  31. msprobe/mindspore/dump/dump_tool_factory.py +38 -0
  32. msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
  33. msprobe/mindspore/ms_config.py +78 -0
  34. msprobe/mindspore/overflow_check/__init__.py +0 -0
  35. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
  36. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
  37. msprobe/mindspore/task_handler_factory.py +21 -0
  38. msprobe/msprobe.py +67 -0
  39. msprobe/pytorch/__init__.py +4 -0
  40. msprobe/pytorch/advisor/advisor.py +124 -0
  41. msprobe/pytorch/advisor/advisor_const.py +59 -0
  42. msprobe/pytorch/advisor/advisor_result.py +58 -0
  43. msprobe/pytorch/api_accuracy_checker/.keep +0 -0
  44. msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
  45. msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
  46. msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
  47. msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
  48. msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
  49. msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
  50. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
  51. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
  52. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
  53. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
  54. msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
  55. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
  58. msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
  59. msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
  60. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
  61. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
  62. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
  63. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
  64. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
  65. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
  66. msprobe/pytorch/common/__init__.py +2 -0
  67. msprobe/pytorch/common/compare_script.template +14 -0
  68. msprobe/pytorch/common/log.py +32 -0
  69. msprobe/pytorch/common/parse_json.py +37 -0
  70. msprobe/pytorch/common/utils.py +224 -0
  71. msprobe/pytorch/compare/acc_compare.py +1024 -0
  72. msprobe/pytorch/compare/distributed_compare.py +111 -0
  73. msprobe/pytorch/compare/highlight.py +100 -0
  74. msprobe/pytorch/compare/mapping.yaml +607 -0
  75. msprobe/pytorch/compare/match.py +36 -0
  76. msprobe/pytorch/compare/npy_compare.py +244 -0
  77. msprobe/pytorch/debugger/__init__.py +0 -0
  78. msprobe/pytorch/debugger/debugger_config.py +86 -0
  79. msprobe/pytorch/debugger/precision_debugger.py +95 -0
  80. msprobe/pytorch/doc/FAQ.md +193 -0
  81. msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
  82. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
  83. msprobe/pytorch/doc/dump.md +207 -0
  84. msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
  85. msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
  86. msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
  87. msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
  88. msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
  89. msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
  90. msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
  91. msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
  92. msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
  93. msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
  94. msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
  95. msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
  96. msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
  97. msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
  98. msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
  99. msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
  100. msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
  101. msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
  102. msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
  103. msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
  104. msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
  105. msprobe/pytorch/doc/img/cpu_info.png +0 -0
  106. msprobe/pytorch/doc/img/module_compare.png +0 -0
  107. msprobe/pytorch/doc/parse_tool.md +286 -0
  108. msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
  109. msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
  110. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
  111. msprobe/pytorch/doc/run_overflow_check.md +25 -0
  112. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +90 -0
  113. msprobe/pytorch/free_benchmark/__init__.py +8 -0
  114. msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
  115. msprobe/pytorch/free_benchmark/common/constant.py +67 -0
  116. msprobe/pytorch/free_benchmark/common/counter.py +72 -0
  117. msprobe/pytorch/free_benchmark/common/enums.py +37 -0
  118. msprobe/pytorch/free_benchmark/common/params.py +129 -0
  119. msprobe/pytorch/free_benchmark/common/utils.py +98 -0
  120. msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
  121. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
  122. msprobe/pytorch/free_benchmark/main.py +102 -0
  123. msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
  124. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
  125. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
  126. msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
  127. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
  132. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
  133. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
  134. msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
  135. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
  136. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
  137. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
  138. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
  139. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
  140. msprobe/pytorch/functional/__init__.py +0 -0
  141. msprobe/pytorch/functional/data_processor.py +0 -0
  142. msprobe/pytorch/functional/dump_module.py +39 -0
  143. msprobe/pytorch/hook_module/__init__.py +1 -0
  144. msprobe/pytorch/hook_module/api_registry.py +161 -0
  145. msprobe/pytorch/hook_module/hook_module.py +109 -0
  146. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
  147. msprobe/pytorch/hook_module/utils.py +29 -0
  148. msprobe/pytorch/hook_module/wrap_aten.py +100 -0
  149. msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
  150. msprobe/pytorch/hook_module/wrap_functional.py +108 -0
  151. msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
  152. msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
  153. msprobe/pytorch/hook_module/wrap_torch.py +88 -0
  154. msprobe/pytorch/hook_module/wrap_vf.py +64 -0
  155. msprobe/pytorch/module_processer.py +98 -0
  156. msprobe/pytorch/online_dispatch/__init__.py +20 -0
  157. msprobe/pytorch/online_dispatch/compare.py +236 -0
  158. msprobe/pytorch/online_dispatch/dispatch.py +274 -0
  159. msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
  160. msprobe/pytorch/online_dispatch/single_compare.py +391 -0
  161. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
  162. msprobe/pytorch/online_dispatch/utils.py +187 -0
  163. msprobe/pytorch/parse.py +4 -0
  164. msprobe/pytorch/parse_tool/__init__.py +0 -0
  165. msprobe/pytorch/parse_tool/cli.py +32 -0
  166. msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
  167. msprobe/pytorch/parse_tool/lib/compare.py +259 -0
  168. msprobe/pytorch/parse_tool/lib/config.py +51 -0
  169. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
  170. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
  171. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
  172. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
  173. msprobe/pytorch/parse_tool/lib/utils.py +367 -0
  174. msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
  175. msprobe/pytorch/pt_config.py +93 -0
  176. msprobe/pytorch/service.py +167 -0
  177. msprobe/test/core_ut/common/test_utils.py +345 -0
  178. msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
  179. msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
  180. msprobe/test/core_ut/data_dump/test_scope.py +151 -0
  181. msprobe/test/core_ut/test_common_config.py +152 -0
  182. msprobe/test/core_ut/test_file_check.py +218 -0
  183. msprobe/test/core_ut/test_log.py +109 -0
  184. msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
  185. msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
  186. msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
  187. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
  188. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
  189. msprobe/test/mindspore_ut/test_ms_config.py +69 -0
  190. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
  191. msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
  192. msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
  193. msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
  194. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
  195. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
  196. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
  197. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
  198. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
  199. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
  200. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
  201. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
  202. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
  203. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
  204. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
  205. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
  206. msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
  207. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
  208. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
  209. msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
  210. msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
  211. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
  212. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
  213. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
  214. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
  215. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
  216. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
  217. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
  218. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
  219. msprobe/test/pytorch_ut/test_pt_config.py +69 -0
  220. msprobe/test/pytorch_ut/test_service.py +59 -0
  221. msprobe/test/resources/advisor.txt +3 -0
  222. msprobe/test/resources/compare_result_20230703104808.csv +9 -0
  223. msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
  224. msprobe/test/resources/config.yaml +3 -0
  225. msprobe/test/resources/npu_test.pkl +8 -0
  226. msprobe/test/run_test.sh +30 -0
  227. msprobe/test/run_ut.py +58 -0
  228. msprobe/test/test_module_processer.py +64 -0
@@ -0,0 +1,35 @@
1
+ import unittest
2
+ import torch.distributed as dist
3
+ from msprobe.pytorch.hook_module.wrap_distributed import *
4
+
5
class TestWrapDistributed(unittest.TestCase):
    """Smoke tests for the ``torch.distributed`` op-wrapping helpers."""

    def hook(name, prefix):
        """Return a stub ``(forward_pre_hook, forward_hook, backward_hook)`` triple.

        This is defined without ``self`` but accessed as ``self.hook``, so the
        test-case instance is bound to ``name``.
        NOTE(review): presumably the wrapper calls the hook builder with one
        argument, which then lands in ``prefix`` -- confirm against
        ``wrap_distributed``.
        """
        def forward_pre_hook(nope, input, kwargs):
            return input, kwargs

        def forward_hook(nope, input, kwargs, result):
            return 2

        def backward_hook():
            pass

        return forward_pre_hook, forward_hook, backward_hook

    def test_get_distributed_ops(self):
        """``get_distributed_ops`` returns a set."""
        ops = get_distributed_ops()
        self.assertIsInstance(ops, set)

    def test_DistributedOPTemplate(self):
        """The template records the op name it was built for."""
        # No explicit self.setUp() here: unittest runs setUp before each test.
        op_name = 'all_reduce'
        if op_name not in get_distributed_ops():
            # Report a skip instead of passing without asserting anything.
            self.skipTest(op_name + " is not among the wrapped distributed ops")
        op = DistributedOPTemplate(op_name, self.hook)
        self.assertEqual(op.op_name_, op_name)

    def test_wrap_distributed_op(self):
        """``wrap_distributed_op`` returns a callable."""
        op_name = 'all_reduce'
        if op_name not in get_distributed_ops():
            # Report a skip instead of passing without asserting anything.
            self.skipTest(op_name + " is not among the wrapped distributed ops")
        wrapped_op = wrap_distributed_op(op_name, self.hook)
        self.assertTrue(callable(wrapped_op))

    def test_wrap_distributed_ops_and_bind(self):
        """Binding exposes a ``wrap_<op>`` attribute for every distributed op."""
        wrap_distributed_ops_and_bind(self.hook)
        for op_name in get_distributed_ops():
            self.assertTrue(hasattr(HOOKDistributedOP, "wrap_" + str(op_name)))
@@ -0,0 +1,20 @@
1
+ import unittest
2
+ import torch
3
+ from msprobe.pytorch.hook_module import wrap_functional as wf
4
+
5
class TestWrapFunctional(unittest.TestCase):
    """Tests for the ``torch.nn.functional`` wrapping helpers in ``wrap_functional``."""

    def test_remove_dropout(self):
        """After ``remove_dropout`` the functional dropout returns its input unchanged."""
        sample = torch.randn(20, 16)
        wf.remove_dropout()
        after_dropout = torch.nn.functional.dropout(sample)
        unchanged = torch.equal(sample, after_dropout)
        self.assertTrue(unchanged)

    def test_get_functional_ops(self):
        """A few well-known ops are among the functional ops reported."""
        required = {'relu', 'sigmoid', 'softmax'}
        self.assertTrue(required.issubset(wf.get_functional_ops()))

    def test_wrap_functional_ops_and_bind(self):
        """Binding exposes ``wrap_relu`` on ``HOOKFunctionalOP``."""
        wf.wrap_functional_ops_and_bind(None)
        has_wrapper = hasattr(wf.HOOKFunctionalOP, 'wrap_relu')
        self.assertTrue(has_wrapper)
@@ -0,0 +1,35 @@
1
+ import unittest
2
+ import torch
3
+ import yaml
4
+ from msprobe.pytorch.hook_module.wrap_tensor import get_tensor_ops, HOOKTensor, TensorOPTemplate, wrap_tensor_op, wrap_tensor_ops_and_bind
5
+
6
class TestWrapTensor(unittest.TestCase):
    """Smoke tests for the ``torch.Tensor`` op-wrapping helpers."""

    def hook(name, prefix):
        """Return a stub ``(forward_pre_hook, forward_hook, backward_hook)`` triple.

        This is defined without ``self`` but accessed as ``self.hook``, so the
        test-case instance is bound to ``name``.
        NOTE(review): presumably the wrapper calls the hook builder with one
        argument, which then lands in ``prefix`` -- confirm against
        ``wrap_tensor``.
        """
        def forward_pre_hook(nope, input, kwargs):
            return input, kwargs

        def forward_hook(nope, input, kwargs, result):
            return 2

        def backward_hook():
            pass

        return forward_pre_hook, forward_hook, backward_hook

    def test_get_tensor_ops(self):
        """``get_tensor_ops`` returns a set."""
        result = get_tensor_ops()
        self.assertIsInstance(result, set)

    def test_HOOKTensor(self):
        """``HOOKTensor`` can be instantiated without arguments."""
        hook_tensor = HOOKTensor()
        self.assertIsInstance(hook_tensor, HOOKTensor)

    def test_TensorOPTemplate(self):
        """The template records the op name it was built for."""
        tensor_op_template = TensorOPTemplate('add', self.hook)
        # assertEqual, not assertTrue: assertTrue(x, 'add') would treat 'add'
        # as the failure message and pass for any truthy op name.
        self.assertEqual(tensor_op_template.op_name_, 'add')

    def test_wrap_tensor_op(self):
        """``wrap_tensor_op`` returns a callable."""
        wrapped_op = wrap_tensor_op('add', self.hook)
        self.assertTrue(callable(wrapped_op))

    def test_wrap_tensor_ops_and_bind(self):
        """Binding exposes ``wrap_add`` on ``HOOKTensor``."""
        wrap_tensor_ops_and_bind(self.hook)
        self.assertTrue(hasattr(HOOKTensor, 'wrap_add'))
@@ -0,0 +1,43 @@
1
+ import unittest
2
+ import torch
3
+ import yaml
4
+ from msprobe.pytorch.hook_module.wrap_torch import *
5
+
6
class TestWrapTorch(unittest.TestCase):
    """Smoke tests for the ``torch.*`` op-wrapping helpers."""

    def hook(name, prefix):
        """Return a stub ``(forward_pre_hook, forward_hook, backward_hook)`` triple.

        This is defined without ``self`` but accessed as ``self.hook``, so the
        test-case instance is bound to ``name``.
        NOTE(review): presumably the wrapper calls the hook builder with one
        argument, which then lands in ``prefix`` -- confirm against
        ``wrap_torch``.
        """
        def forward_pre_hook(nope, input, kwargs):
            return input, kwargs

        def forward_hook(nope, input, kwargs, result):
            return 2

        def backward_hook():
            pass

        return forward_pre_hook, forward_hook, backward_hook

    def setUp(self):
        """Build the shared fixture; unittest calls this before every test."""
        self.op_name = 'add'
        self.torch_op = wrap_torch_op(self.op_name, self.hook)

    # The tests below do not call self.setUp() themselves: the framework has
    # already run it, and calling it again only rebuilds the same fixture.

    def test_get_torch_ops(self):
        """``get_torch_ops`` returns a set that contains the fixture op."""
        ops = get_torch_ops()
        self.assertIsInstance(ops, set)
        self.assertIn(self.op_name, ops)

    def test_TorchOPTemplate(self):
        """The template records the op name and a ``Torch.<op>.`` prefix."""
        template = TorchOPTemplate(self.op_name, self.hook)
        self.assertEqual(template.op_name_, self.op_name)
        self.assertEqual(template.prefix_op_name_, "Torch." + str(self.op_name) + ".")

    def test_forward(self):
        """Calling ``forward`` on the ``add`` template adds its two operands."""
        template = TorchOPTemplate(self.op_name, self.hook)
        result = template.forward(torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]))
        torch.testing.assert_close(result, torch.tensor([5, 7, 9]))

    def test_wrap_torch_ops_and_bind(self):
        """Binding exposes ``wrap_<op>`` on ``HOOKTorchOP``."""
        wrap_torch_ops_and_bind(self.hook)
        self.assertTrue(hasattr(HOOKTorchOP, "wrap_" + self.op_name))
@@ -0,0 +1,11 @@
1
+ import unittest
2
+ import torch
3
+ from msprobe.pytorch.hook_module import wrap_vf
4
+
5
class TestWrapVF(unittest.TestCase):
    """Smoke test for the ``wrap_vf`` helpers."""

    def setUp(self):
        """Install an identity callable as the hook fixture."""
        def identity(x):
            return x

        self.hook = identity

    def test_get_vf_ops(self):
        """``get_vf_ops`` returns a list."""
        self.assertIsInstance(wrap_vf.get_vf_ops(), list)
@@ -0,0 +1,69 @@
1
+ from unittest import TestCase
2
+ from unittest.mock import patch, mock_open
3
+
4
+ from msprobe.core.common.const import Const
5
+ from msprobe.pytorch.pt_config import parse_json_config, parse_task_config
6
+
7
+
8
class TestPtConfig(TestCase):
    """Tests for ``parse_json_config`` and ``parse_task_config`` in ``pt_config``."""

    def test_parse_json_config(self):
        """Parse a mocked config file with and without an explicit task argument."""
        fake_config = {
            "task": "statistics",
            "dump_path": "./dump/",
            "rank": [],
            "step": [],
            "level": "L1",
            "seed": 1234,
            "statistics": {
                "scope": [],
                "list": [],
                "data_mode": ["all"],
            },
            "tensor": {
                "file_format": "npy"
            }
        }
        # (task argument, attribute checked on the task config, expected value)
        cases = (
            (None, "data_mode", ["all"]),
            (Const.TENSOR, "file_format", "npy"),
        )
        for task_arg, checked_attr, expected_value in cases:
            # File access and JSON decoding are mocked so no real file is read.
            with patch("msprobe.pytorch.pt_config.os.path.join", return_value="/path/config.json"), \
                    patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \
                    patch("msprobe.pytorch.pt_config.json.load", return_value=fake_config):
                common_config, task_config = parse_json_config(None, task_arg)
            self.assertEqual(common_config.task, Const.STATISTICS)
            self.assertEqual(getattr(task_config, checked_attr), expected_value)

    def test_parse_task_config(self):
        """Build task configs for the overflow-check and free-benchmark tasks."""
        overflow_settings = {
            "overflow_check": {
                "overflow_nums": 1,
                "check_mode": "all"
            }
        }
        overflow_result = parse_task_config(Const.OVERFLOW_CHECK, overflow_settings)
        self.assertEqual(overflow_result.overflow_num, 1)
        self.assertEqual(overflow_result.check_mode, "all")

        benchmark_settings = {
            "free_benchmark": {
                "scope": [],
                "list": ["conv2d"],
                "fuzz_device": "npu",
                "pert_mode": "improve_precision",
                "handler_type": "check",
                "fuzz_level": "L1",
                "fuzz_stage": "forward",
                "if_preheat": False,
                "preheat_step": 15,
                "max_sample": 20
            }
        }
        benchmark_result = parse_task_config(Const.FREE_BENCHMARK, benchmark_settings)
        self.assertEqual(benchmark_result.pert_mode, "improve_precision")
        self.assertEqual(benchmark_result.handler_type, "check")
        self.assertEqual(benchmark_result.preheat_step, 15)
        self.assertEqual(benchmark_result.max_sample, 20)
@@ -0,0 +1,59 @@
1
+ import unittest
2
+ from unittest.mock import patch, mock_open
3
+
4
+ import torch.nn as nn
5
+ from msprobe.core.common.utils import Const
6
+ from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
7
+ from msprobe.pytorch.pt_config import parse_json_config
8
+ from msprobe.pytorch.service import Service
9
+
10
+
11
class TestService(unittest.TestCase):
    """Tests for ``Service`` built from a mocked statistics configuration."""

    def setUp(self):
        """Create a ``DebuggerConfig`` and ``Service`` without reading a real config file."""
        fake_config = {
            "dump_path": "./dump/",
        }
        file_open_patch = patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data=''))
        json_load_patch = patch("msprobe.pytorch.pt_config.json.load", return_value=fake_config)
        with file_open_patch, json_load_patch:
            common_config, task_config = parse_json_config("./config.json", Const.STATISTICS)
        self.config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./ut_dump", "L1")
        self.service = Service(self.config)

    def test_start(self):
        """``start`` stores the rank reported by ``get_rank_if_initialized``."""
        rank_patch = patch("msprobe.pytorch.service.get_rank_if_initialized", return_value=0)
        dirs_patch = patch("msprobe.pytorch.service.Service.create_dirs", return_value=None)
        with rank_patch, dirs_patch:
            self.service.start(None)
        self.assertEqual(self.service.current_rank, 0)

    def test_stop_and_step(self):
        """``stop`` clears the switch and ``step`` advances the iteration counter."""
        write_json_target = "msprobe.core.data_dump.data_collector.DataCollector.write_json"
        with patch(write_json_target, return_value=None):
            self.service.stop()
        self.assertFalse(self.service.switch)

        self.service.step()
        self.assertEqual(self.service.current_iter, 1)

    def test_register_hook_new(self):
        """Registering hooks at level ``L0`` logs twice on rank 0."""
        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(in_features=8, out_features=4)

            def forward(self, x):
                return self.linear(x)

        self.service.model = TestModule()
        self.config.level = "L0"
        with patch("msprobe.pytorch.service.logger.info_on_rank_0") as mock_logger, \
                patch("msprobe.pytorch.service.remove_dropout", return_value=None):
            self.service.register_hook_new()
        self.assertEqual(mock_logger.call_count, 2)

    def test_create_dirs(self):
        """``create_dirs`` points ``dump_iter_dir`` at the ``step0`` directory."""
        mkdir_patch = patch("msprobe.pytorch.service.Path.mkdir", return_value=None)
        check_patch = patch("msprobe.core.common.file_check.FileChecker.common_check", return_value=None)
        paths_patch = patch("msprobe.core.data_dump.data_collector.DataCollector.update_dump_paths",
                            return_value=None)
        with mkdir_patch, check_patch, paths_patch:
            self.service.create_dirs()
        self.assertEqual(self.service.dump_iter_dir, "./ut_dump/step0")
@@ -0,0 +1,3 @@
1
+ Line: NA
2
+ Suspect Nodes: NA
3
+ Expert Advice: All data in comparison result meets the accuracy requirements.
@@ -0,0 +1,9 @@
1
+ NPU Name,Bench Name,NPU Tensor Dtype,Bench Tensor Dtype,NPU Tensor Shape,Bench Tensor Shape,Cosine,MaxAbsErr,NPU max,NPU min,NPU mean,Bench max,Bench min,Bench mean,Accuracy Reached or Not,Err_message
2
+ Functional_linear_0_forward_input.0,Functional_linear_0_forward_input.0,torch.float32,torch.float32,"[3, 2]","[3, 2]",1.0,0.000000,1.948258399963379,-1.0052297115325928,-0.2003595232963562,1.948258399963379,-1.0052297115325928,-0.2003595232963562,Yes,
3
+ Functional_linear_0_forward_input.1,Functional_linear_0_forward_input.1,torch.float32,torch.float32,"[3, 2]","[3, 2]",1.0,0.000000,0.28375449776649475,-0.6661239266395569,-0.2789986729621887,0.28375449776649475,-0.6661239266395569,-0.2789986729621887,Yes,
4
+ Functional_linear_0_forward_input.2,Functional_linear_0_forward_input.2,torch.float32,torch.float32,[3],[3],1.0,0.000000,0.2457989901304245,-0.6338542103767395,-0.14437106251716614,0.2457989901304245,-0.6338542103767395,-0.14437106251716614,Yes,
5
+ Functional_linear_0_forward_output,Functional_linear_0_forward_output,torch.float32,torch.float32,"[3, 3]","[3, 3]",1.0,0.000000,0.8278868794441223,-0.8729169964790344,0.16790540516376495,0.8278868794441223,-0.8729169964790344,0.16790540516376495,Yes,
6
+ Torch_relu_0_forward_input.0,Torch_relu_0_forward_input.0,torch.float32,torch.float32,"[3, 3]","[3, 3]",1.0,0.000000,0.8278868794441223,-0.8729169964790344,0.16790540516376495,0.8278868794441223,-0.8729169964790344,0.16790540516376495,Yes,
7
+ Torch_relu_0_forward_output,Torch_relu_0_forward_output,torch.float32,torch.float32,"[3, 3]","[3, 3]",1.0,0.000000,0.8278868794441223,0.0,0.31367552280426025,0.8278868794441223,0.0,0.31367552280426025,Yes,
8
+ Functional_relu_0_forward_input.0,Functional_relu_0_forward_input.0,torch.float32,torch.float32,"[3, 3]","[3, 3]",1.0,0.000000,0.8278868794441223,-0.8729169964790344,0.16790540516376495,0.8278868794441223,-0.8729169964790344,0.16790540516376495,Yes,
9
+ Functional_relu_0_forward_output,Functional_relu_0_forward_output,torch.float32,torch.float32,"[3, 3]","[3, 3]",1.0,0.000000,0.8278868794441223,0.0,0.31367552280426025,0.8278868794441223,0.0,0.31367552280426025,Yes,
@@ -0,0 +1,9 @@
1
+ NPU Name,Bench Name,NPU Tensor Dtype,Bench Tensor Dtype,NPU Tensor Shape,Bench Tensor Shape,Cosine,MaxAbsErr,NPU max,NPU min,NPU mean,Bench max,Bench min,Bench mean,Accuracy Reached or Not,Err_message
2
+ ,Functional_linear_0_forward_input.0,torch.float32,torch.float32,"[3, 2]","[3, 2]",1,0,1.9482584,-1.005229712,-0.200359523,1.9482584,-1.005229712,-0.200359523,,
3
+ ,Functional_linear_0_forward_input.1,torch.float32,torch.float32,"[3, 2]","[3, 2]",1,0,0.283754498,-0.666123927,-0.278998673,0.283754498,-0.666123927,-0.278998673,,
4
+ ,Functional_linear_0_forward_input.2,torch.float32,torch.float32,[3],[3],1,0,0.24579899,-0.63385421,-0.144371063,0.24579899,-0.63385421,-0.144371063,,
5
+ ,Functional_linear_0_forward_output,torch.float32,torch.float32,"[3, 3]","[3, 3]",1,0,0.827886879,-0.872916996,0.167905405,0.827886879,-0.872916996,0.167905405,,
6
+ ,Torch_relu_0_forward_input.0,torch.float32,torch.float32,"[3, 3]","[3, 3]",1,0,0.827886879,-0.872916996,0.167905405,0.827886879,-0.872916996,0.167905405,,
7
+ ,Torch_relu_0_forward_output,torch.float32,torch.float32,"[3, 3]","[3, 3]",1,0,0.827886879,0,0.313675523,0.827886879,0,0.313675523,,
8
+ ,Functional_relu_0_forward_input.0,torch.float32,torch.float32,"[3, 3]","[3, 3]",1,0,0.827886879,-0.872916996,0.167905405,0.827886879,-0.872916996,0.167905405,,
9
+ ,Functional_relu_0_forward_output,torch.float32,torch.float32,"[3, 3]","[3, 3]",1,0,0.827886879,0,0.313675523,0.827886879,0,0.313675523,,
@@ -0,0 +1,3 @@
1
+ white_list: []
2
+ error_data_path: './'
3
+ precision: 14
@@ -0,0 +1,8 @@
1
+ ["Functional_linear_0_forward_input.0", 1, [], "torch.float32", [3, 2], [1.948258399963379, -1.0052297115325928, -0.2003595232963562]]
2
+ ["Functional_linear_0_forward_input.1", 1, [], "torch.float32", [3, 2], [0.28375449776649475, -0.6661239266395569, -0.2789986729621887]]
3
+ ["Functional_linear_0_forward_input.2", 1, [], "torch.float32", [3], [0.2457989901304245, -0.6338542103767395, -0.14437106251716614]]
4
+ ["Functional_linear_0_forward_output", 1, [], "torch.float32", [3, 3], [0.8278868794441223, -0.8729169964790344, 0.16790540516376495]]
5
+ ["Torch_relu_0_forward_input.0", 1, [], "torch.float32", [3, 3], [0.8278868794441223, -0.8729169964790344, 0.16790540516376495]]
6
+ ["Torch_relu_0_forward_output", 1, [], "torch.float32", [3, 3], [0.8278868794441223, 0.0, 0.31367552280426025]]
7
+ ["Functional_relu_0_forward_input.0", 1, [], "torch.float32", [3, 3], [0.8278868794441223, -0.8729169964790344, 0.16790540516376495]]
8
+ ["Functional_relu_0_forward_output", 1, [], "torch.float32", [3, 3], [0.8278868794441223, 0.0, 0.31367552280426025]]
@@ -0,0 +1,30 @@
1
#!/bin/bash
# Resolve the directories this runner works with, relative to the script itself.
# Every expansion is quoted so that a checkout path containing spaces still works.
CUR_DIR="$(dirname "$(readlink -f "$0")")"
TOP_DIR="${CUR_DIR}/.."
TEST_DIR="${TOP_DIR}/test"
SRC_DIR="${TOP_DIR}/../"
6
+
7
# Make sure pytest and pytest-cov are installed, installing whichever is missing.
# pip is invoked as "python3 -m pip" so the packages land in the same interpreter
# that later runs the tests, rather than whatever "pip" happens to be on PATH.
install_pytest() {
    local pkg
    for pkg in pytest pytest-cov; do
        if ! python3 -m pip show "${pkg}" &> /dev/null; then
            echo "${pkg} not found, trying to install..."
            python3 -m pip install "${pkg}"
        fi
    done
}
18
+
19
# Install the test dependencies, put the source root on PYTHONPATH,
# then launch the Python test driver from the current directory.
run_ut() {
    install_pytest

    PYTHONPATH="${SRC_DIR}:${PYTHONPATH}"
    export PYTHONPATH
    python3 run_ut.py
}
25
+
26
# Entry point: switch to the test directory and run the unit tests.
# TEST_DIR is quoted so a path containing spaces is not split into several words.
main() {
    cd "${TEST_DIR}" && run_ut
}

# "$@" forwards the script arguments unchanged, one word per original argument.
main "$@"
msprobe/test/run_ut.py ADDED
@@ -0,0 +1,58 @@
1
+ import os
2
+ import shutil
3
+ import subprocess
4
+ import sys
5
+
6
+ from msprobe.core.common.log import logger
7
+
8
+
9
def run_ut():
    """Run the unit tests under pytest with coverage, streaming output to the logger.

    The directory containing this file is used as the test root. A fresh
    ``report`` directory is created next to it to receive the JUnit XML
    (``final.xml``) and the coverage XML (``coverage.xml``); coverage is
    measured over the parent directory of this file.

    Returns:
        bool: True when pytest exits with status 0. False when it exits with
        any other status, or when launching or reading from the subprocess
        raises.
    """
    cur_dir = os.path.realpath(os.path.dirname(__file__))
    ut_path = cur_dir
    cov_dir = os.path.dirname(cur_dir)
    report_dir = os.path.join(cur_dir, "report")
    final_xml_path = os.path.join(report_dir, "final.xml")
    cov_report_path = os.path.join(report_dir, "coverage.xml")

    # Start from an empty report directory so stale results never survive a run.
    if os.path.exists(report_dir):
        shutil.rmtree(report_dir)
    os.makedirs(report_dir)

    # Use the interpreter that is running this script so pytest and pytest-cov
    # are looked up in the same environment; fall back to "python3" only when
    # the interpreter path is unavailable.
    python_exe = sys.executable or "python3"
    pytest_cmd = [
        python_exe, "-m", "pytest",
        ut_path,
        f"--junitxml={final_xml_path}",
        f"--cov={cov_dir}",
        "--cov-branch",
        f"--cov-report=xml:{cov_report_path}",
    ]

    try:
        with subprocess.Popen(
            pytest_cmd,
            shell=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr so output is logged in order
            text=True,
        ) as proc:
            for line in proc.stdout:
                logger.info(line.strip())

            proc.wait()

            if proc.returncode == 0:
                logger.info("Unit tests executed successfully.")
                return True
            else:
                logger.error("Unit tests execution failed.")
                return False
    except Exception as e:
        # Top-level boundary: report any failure to launch or read the process.
        logger.error(f"An error occurred during test execution: {e}")
        return False
52
+
53
+
54
if __name__ == "__main__":
    # Exit status 0 when the test run succeeded, 1 otherwise.
    sys.exit(0 if run_ut() else 1)
@@ -0,0 +1,64 @@
1
+ import unittest
2
+ from msprobe.pytorch.module_processer import ModuleProcesser
3
+ from msprobe.pytorch.common.utils import Const
4
+
5
+ import torch
6
+
7
class TestModuleProcesser(unittest.TestCase):
    """Tests for the ``ModuleProcesser`` decorators and hook bookkeeping."""

    def test_filter_tensor_and_tuple(self):
        """Exercise the decorated function with a tensor and with a string argument."""
        def double(nope, x):
            return x * 2

        tensor_result = ModuleProcesser.filter_tensor_and_tuple(double)(None, torch.tensor([1]))
        self.assertEqual(tensor_result, torch.tensor([2]))

        string_result = ModuleProcesser.filter_tensor_and_tuple(double)(None, "test")
        self.assertEqual(string_result, "test")

    def test_clone_return_value_and_test_clone_if_tensor(self):
        """Mutating a wrapped function's return value must not alter the original argument."""
        def identity(x):
            return x

        single = torch.tensor([1])
        as_tuple = (torch.tensor([1]), torch.tensor([2]))
        as_list = [torch.tensor([1]), torch.tensor([2])]
        as_dict = {"A": torch.tensor([1]), "B": torch.tensor([2])}

        cloning_identity = ModuleProcesser.clone_return_value(identity)

        cloned_single = cloning_identity(single)
        cloned_single[0] = 2
        self.assertNotEqual(cloned_single, single)

        cloned_tuple = cloning_identity(as_tuple)
        cloned_tuple[0][0] = 2
        self.assertNotEqual(cloned_tuple, as_tuple)

        cloned_list = cloning_identity(as_list)
        cloned_list[0][0] = 2
        self.assertNotEqual(cloned_list, as_list)

        cloned_dict = cloning_identity(as_dict)
        cloned_dict["A"][0] = 2
        self.assertNotEqual(cloned_dict, as_dict)

    def test_node_hook(self):
        """Running the start hook records the module name and fills the module stack."""
        processer = ModuleProcesser(None)
        start_hook = processer.node_hook("test", Const.START)
        self.assertIsNotNone(start_hook)
        stop_hook = processer.node_hook("test", "stop")
        self.assertIsNotNone(stop_hook)

        class A():
            pass

        start_hook(A, None, None)
        self.assertIn("test", processer.module_count)
        self.assertFalse(processer.module_stack == [])

    def test_module_count_func(self):
        """The first call for a module name stores a count of 0."""
        processer = ModuleProcesser(None)
        self.assertEqual(processer.module_count, {})

        processer.module_count_func("nope")
        self.assertEqual(processer.module_count["nope"], 0)