mindstudio-probe 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228)
  1. mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
  2. mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
  3. mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
  4. mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
  5. mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
  6. mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
  7. msprobe/README.md +182 -0
  8. msprobe/__init__.py +0 -0
  9. msprobe/config/README.md +397 -0
  10. msprobe/config/config.json +28 -0
  11. msprobe/config/img/free_benchmark.png +0 -0
  12. msprobe/core/common/const.py +241 -0
  13. msprobe/core/common/exceptions.py +88 -0
  14. msprobe/core/common/file_check.py +265 -0
  15. msprobe/core/common/log.py +55 -0
  16. msprobe/core/common/utils.py +516 -0
  17. msprobe/core/common_config.py +58 -0
  18. msprobe/core/data_dump/data_collector.py +140 -0
  19. msprobe/core/data_dump/data_processor/base.py +245 -0
  20. msprobe/core/data_dump/data_processor/factory.py +61 -0
  21. msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
  22. msprobe/core/data_dump/json_writer.py +116 -0
  23. msprobe/core/data_dump/scope.py +178 -0
  24. msprobe/mindspore/__init__.py +1 -0
  25. msprobe/mindspore/debugger/__init__.py +0 -0
  26. msprobe/mindspore/debugger/debugger_config.py +51 -0
  27. msprobe/mindspore/debugger/precision_debugger.py +32 -0
  28. msprobe/mindspore/doc/dump.md +65 -0
  29. msprobe/mindspore/dump/__init__.py +0 -0
  30. msprobe/mindspore/dump/api_kbk_dump.py +55 -0
  31. msprobe/mindspore/dump/dump_tool_factory.py +38 -0
  32. msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
  33. msprobe/mindspore/ms_config.py +78 -0
  34. msprobe/mindspore/overflow_check/__init__.py +0 -0
  35. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
  36. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
  37. msprobe/mindspore/task_handler_factory.py +21 -0
  38. msprobe/msprobe.py +67 -0
  39. msprobe/pytorch/__init__.py +4 -0
  40. msprobe/pytorch/advisor/advisor.py +124 -0
  41. msprobe/pytorch/advisor/advisor_const.py +59 -0
  42. msprobe/pytorch/advisor/advisor_result.py +58 -0
  43. msprobe/pytorch/api_accuracy_checker/.keep +0 -0
  44. msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
  45. msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
  46. msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
  47. msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
  48. msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
  49. msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
  50. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
  51. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
  52. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
  53. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
  54. msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
  55. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
  58. msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
  59. msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
  60. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
  61. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
  62. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
  63. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
  64. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
  65. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
  66. msprobe/pytorch/common/__init__.py +2 -0
  67. msprobe/pytorch/common/compare_script.template +14 -0
  68. msprobe/pytorch/common/log.py +32 -0
  69. msprobe/pytorch/common/parse_json.py +37 -0
  70. msprobe/pytorch/common/utils.py +224 -0
  71. msprobe/pytorch/compare/acc_compare.py +1024 -0
  72. msprobe/pytorch/compare/distributed_compare.py +111 -0
  73. msprobe/pytorch/compare/highlight.py +100 -0
  74. msprobe/pytorch/compare/mapping.yaml +607 -0
  75. msprobe/pytorch/compare/match.py +36 -0
  76. msprobe/pytorch/compare/npy_compare.py +244 -0
  77. msprobe/pytorch/debugger/__init__.py +0 -0
  78. msprobe/pytorch/debugger/debugger_config.py +86 -0
  79. msprobe/pytorch/debugger/precision_debugger.py +95 -0
  80. msprobe/pytorch/doc/FAQ.md +193 -0
  81. msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
  82. msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +182 -0
  83. msprobe/pytorch/doc/dump.md +207 -0
  84. msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
  85. msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
  86. msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
  87. msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
  88. msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
  89. msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
  90. msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
  91. msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
  92. msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
  93. msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
  94. msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
  95. msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
  96. msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
  97. msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
  98. msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
  99. msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
  100. msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
  101. msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
  102. msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
  103. msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
  104. msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
  105. msprobe/pytorch/doc/img/cpu_info.png +0 -0
  106. msprobe/pytorch/doc/img/module_compare.png +0 -0
  107. msprobe/pytorch/doc/parse_tool.md +286 -0
  108. msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
  109. msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
  110. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
  111. msprobe/pytorch/doc/run_overflow_check.md +25 -0
  112. msprobe/pytorch/doc/在线精度比对.md +90 -0
  113. msprobe/pytorch/free_benchmark/__init__.py +8 -0
  114. msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
  115. msprobe/pytorch/free_benchmark/common/constant.py +67 -0
  116. msprobe/pytorch/free_benchmark/common/counter.py +72 -0
  117. msprobe/pytorch/free_benchmark/common/enums.py +37 -0
  118. msprobe/pytorch/free_benchmark/common/params.py +129 -0
  119. msprobe/pytorch/free_benchmark/common/utils.py +98 -0
  120. msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
  121. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
  122. msprobe/pytorch/free_benchmark/main.py +102 -0
  123. msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
  124. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
  125. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
  126. msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
  127. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
  132. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
  133. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
  134. msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
  135. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
  136. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
  137. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
  138. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
  139. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
  140. msprobe/pytorch/functional/__init__.py +0 -0
  141. msprobe/pytorch/functional/data_processor.py +0 -0
  142. msprobe/pytorch/functional/dump_module.py +39 -0
  143. msprobe/pytorch/hook_module/__init__.py +1 -0
  144. msprobe/pytorch/hook_module/api_registry.py +161 -0
  145. msprobe/pytorch/hook_module/hook_module.py +109 -0
  146. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
  147. msprobe/pytorch/hook_module/utils.py +29 -0
  148. msprobe/pytorch/hook_module/wrap_aten.py +100 -0
  149. msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
  150. msprobe/pytorch/hook_module/wrap_functional.py +108 -0
  151. msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
  152. msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
  153. msprobe/pytorch/hook_module/wrap_torch.py +88 -0
  154. msprobe/pytorch/hook_module/wrap_vf.py +64 -0
  155. msprobe/pytorch/module_processer.py +98 -0
  156. msprobe/pytorch/online_dispatch/__init__.py +20 -0
  157. msprobe/pytorch/online_dispatch/compare.py +236 -0
  158. msprobe/pytorch/online_dispatch/dispatch.py +274 -0
  159. msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
  160. msprobe/pytorch/online_dispatch/single_compare.py +391 -0
  161. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
  162. msprobe/pytorch/online_dispatch/utils.py +187 -0
  163. msprobe/pytorch/parse.py +4 -0
  164. msprobe/pytorch/parse_tool/__init__.py +0 -0
  165. msprobe/pytorch/parse_tool/cli.py +32 -0
  166. msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
  167. msprobe/pytorch/parse_tool/lib/compare.py +259 -0
  168. msprobe/pytorch/parse_tool/lib/config.py +51 -0
  169. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
  170. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
  171. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
  172. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
  173. msprobe/pytorch/parse_tool/lib/utils.py +367 -0
  174. msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
  175. msprobe/pytorch/pt_config.py +93 -0
  176. msprobe/pytorch/service.py +167 -0
  177. msprobe/test/core_ut/common/test_utils.py +345 -0
  178. msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
  179. msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
  180. msprobe/test/core_ut/data_dump/test_scope.py +151 -0
  181. msprobe/test/core_ut/test_common_config.py +152 -0
  182. msprobe/test/core_ut/test_file_check.py +218 -0
  183. msprobe/test/core_ut/test_log.py +109 -0
  184. msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
  185. msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
  186. msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
  187. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
  188. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
  189. msprobe/test/mindspore_ut/test_ms_config.py +69 -0
  190. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
  191. msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
  192. msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
  193. msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
  194. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
  195. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
  196. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
  197. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
  198. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
  199. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
  200. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
  201. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
  202. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
  203. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
  204. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
  205. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
  206. msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
  207. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
  208. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
  209. msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
  210. msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
  211. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
  212. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
  213. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
  214. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
  215. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
  216. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
  217. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
  218. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
  219. msprobe/test/pytorch_ut/test_pt_config.py +69 -0
  220. msprobe/test/pytorch_ut/test_service.py +59 -0
  221. msprobe/test/resources/advisor.txt +3 -0
  222. msprobe/test/resources/compare_result_20230703104808.csv +9 -0
  223. msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
  224. msprobe/test/resources/config.yaml +3 -0
  225. msprobe/test/resources/npu_test.pkl +8 -0
  226. msprobe/test/run_test.sh +30 -0
  227. msprobe/test/run_ut.py +58 -0
  228. msprobe/test/test_module_processer.py +64 -0
@@ -0,0 +1,125 @@
1
+ import csv
2
+ import os
3
+ import shutil
4
+ import time
5
+ import unittest
6
+
7
+ import numpy as np
8
+ import torch.nn.functional
9
+
10
+ from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
11
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
12
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import UtDataInfo
13
+
14
# One timestamp shared by both output files so a run's result and details CSVs pair up.
current_time = time.strftime("%Y%m%d%H%M%S")
RESULT_FILE_NAME = f"accuracy_checking_result_{current_time}.csv"
DETAILS_FILE_NAME = f"accuracy_checking_details_{current_time}.csv"
# Directory three levels above this test file; the output directory is created relative to it.
base_dir = os.path.dirname(
    os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))
    )
)
18
+
19
+
20
class TestCompare(unittest.TestCase):
    """Unit tests for ``Comparator`` from the API accuracy checker compare module."""

    def setUp(self):
        # Result/details CSVs written by Comparator during a test land in this directory.
        self.output_path = os.path.join(base_dir, "../compare_result")
        # exist_ok=True: a directory left behind by an earlier interrupted run must not make
        # setUp fail with FileExistsError. unittest skips tearDown when setUp raises, so
        # without this a single failure here would break every following run.
        os.makedirs(self.output_path, mode=0o750, exist_ok=True)
        # Registered right after creation so the directory is removed even if the
        # Comparator construction below raises (cleanups still run when setUp fails).
        self.addCleanup(shutil.rmtree, self.output_path, ignore_errors=True)
        self.result_csv_path = os.path.join(self.output_path, RESULT_FILE_NAME)
        self.details_csv_path = os.path.join(self.output_path, DETAILS_FILE_NAME)
        self.is_continue_run_ut = False
        self.compare = Comparator(self.result_csv_path, self.details_csv_path, self.is_continue_run_ut)

    def tearDown(self) -> None:
        # Remove the output directory and everything the test wrote into it.
        if os.path.exists(self.output_path):
            shutil.rmtree(self.output_path)

    def test_compare_dropout(self):
        """Two independent dropout2d results of the same input are accepted by _compare_dropout."""
        dummy_input = torch.randn(100, 100)
        bench_out = torch.nn.functional.dropout2d(dummy_input, 0.3)
        npu_out = torch.nn.functional.dropout2d(dummy_input, 0.3)
        self.assertTrue(self.compare._compare_dropout(bench_out, npu_out))

    def test_compare_core_wrapper(self):
        """Identical bench/NPU outputs (single tensor and list of tensors) all pass."""
        dummy_input = torch.randn(100, 100)
        # Small absolute tolerance for the cosine similarity of identical tensors.
        tolerance = 1e-4
        # Expected detail row once the cosine similarity (column 3) has been pinned to 1.0.
        expected_row = ['torch.float32', 'torch.float32', (100, 100), 1.0, 0.0, ' ', ' ', ' ',
                        ' ', 0.0, 0.0, 0, 0.0, 0.0, ' ', ' ', ' ', ' ', ' ', ' ', 'pass',
                        '\nMax abs error is less than 0.001, consider as pass, skip other check and set to SPACE.\n']

        # Single tensor output.
        bench_out, npu_out = dummy_input, dummy_input
        test_final_success, detailed_result_total = self.compare._compare_core_wrapper("api", bench_out, npu_out)
        # Check whether the actual cosine similarity lies within the tolerance of 1.0 ...
        actual_cosine_similarity = detailed_result_total[0][3]
        self.assertTrue(np.isclose(actual_cosine_similarity, 1.0, atol=tolerance))
        # ... then pin it so the remaining columns can be compared exactly.
        detailed_result_total[0][3] = 1.0
        self.assertEqual(detailed_result_total, [expected_row])
        self.assertTrue(test_final_success)

        # List of tensor outputs: one detail row per element.
        bench_out, npu_out = [dummy_input, dummy_input], [dummy_input, dummy_input]
        test_final_success, detailed_result_total = self.compare._compare_core_wrapper("api", bench_out, npu_out)
        for row in detailed_result_total:
            self.assertTrue(np.isclose(row[3], 1.0, atol=tolerance))
            row[3] = 1.0
        self.assertTrue(test_final_success)
        self.assertEqual(detailed_result_total, [expected_row, expected_row])

    def test_compare_output(self):
        """Random outputs fail the forward check; identical outputs without grads pass both checks."""
        bench_out, npu_out = torch.randn(100, 100), torch.randn(100, 100)
        bench_grad, npu_grad = [torch.randn(100, 100)], [torch.randn(100, 100)]
        api_name = 'Functional.conv2d.0'
        data_info = UtDataInfo(bench_grad, npu_grad, bench_out, npu_out, None, None, None)
        is_fwd_success, is_bwd_success = self.compare.compare_output(api_name, data_info)
        self.assertFalse(is_fwd_success)
        # TODO(review): is_bwd_success is not asserted for random gradients; confirm the
        # expected result from compare_output before adding an assertion.

        dummy_input = torch.randn(100, 100)
        bench_out, npu_out = dummy_input, dummy_input
        data_info = UtDataInfo(None, None, bench_out, npu_out, None, None, None)
        is_fwd_success, is_bwd_success = self.compare.compare_output(api_name, data_info)
        self.assertTrue(is_fwd_success)
        self.assertTrue(is_bwd_success)

    def test_record_results(self):
        """record_results writes a details row named '<api>.forward.output.0'."""
        args = ('Functional.conv2d.0', False, 'N/A', [['torch.float64', 'torch.float32', (32, 64, 112, 112), 1.0,
                                                        0.012798667686, 'N/A', 0.81631212311, 0.159979121213, 'N/A',
                                                        'error', '\n']], None, 0)
        self.compare.record_results(args)
        # newline='' is what the csv module requires for files handed to csv.reader.
        with open(self.details_csv_path, 'r', newline='', encoding='utf-8') as file:
            csv_reader = csv.reader(file)
            next(csv_reader)  # skip the header row
            api_name_list = [row[0] for row in csv_reader]
        self.assertEqual(api_name_list[0], 'Functional.conv2d.0.forward.output.0')

    def test_compare_torch_tensor(self):
        """Equal float tensors are reported as 'pass'."""
        cpu_output = torch.Tensor([1.0, 2.0, 3.0])
        npu_output = torch.Tensor([1.0, 2.0, 3.0])
        compare_column = CompareColumn()
        status, compare_column, message = self.compare._compare_torch_tensor("api", cpu_output, npu_output,
                                                                             compare_column)
        self.assertEqual(status, "pass")

    def test_compare_bool_tensor(self):
        """Equal bool arrays give error rate 0.0, status 'pass' and an empty message."""
        cpu_output = np.array([True, False, True])
        npu_output = np.array([True, False, True])
        self.assertEqual(self.compare._compare_bool_tensor(cpu_output, npu_output), (0.0, 'pass', ''))

    def test_compare_builtin_type(self):
        """Equal Python ints pass with error rate 0 and an empty message."""
        compare_column = CompareColumn()
        bench_out = 1
        npu_out = 1
        status, compare_result, message = self.compare._compare_builtin_type(bench_out, npu_out, compare_column)
        self.assertEqual((status, compare_result.error_rate, message), ('pass', 0, ''))

    def test_compare_float_tensor(self):
        """Equal float arrays (as numpy) are reported as 'pass'."""
        cpu_output = torch.Tensor([1.0, 2.0, 3.0])
        npu_output = torch.Tensor([1.0, 2.0, 3.0])
        compare_column = CompareColumn()
        status, compare_column, message = self.compare._compare_float_tensor("api", cpu_output.numpy(),
                                                                             npu_output.numpy(),
                                                                             compare_column, npu_output.dtype)
        self.assertEqual(status, "pass")
@@ -0,0 +1,10 @@
1
+ import unittest
2
+
3
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
4
+
5
+
6
class TestCompareColumns(unittest.TestCase):
    """Checks for the column containers defined in ``compare_column``."""

    def test_api_precision_output_column(self):
        """An ``ApiPrecisionOutputColumn`` built with no arguments serialises to a list."""
        column_values = ApiPrecisionOutputColumn().to_column_value()
        self.assertIsInstance(column_values, list)
@@ -0,0 +1,43 @@
1
+ import unittest
2
+
3
+ import numpy as np
4
+
5
+ from msprobe.pytorch.api_accuracy_checker.common.utils import CompareException
6
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, convert_str_to_float
7
+
8
+
9
class TestCompareUtils(unittest.TestCase):
    """Unit tests for ``check_dtype_comparable`` and ``convert_str_to_float``."""

    def test_check_dtype_comparable(self):
        """Same-dtype arrays are comparable; int32 vs float32 and int32 vs bool are not."""
        # (x, y, expected comparability)
        cases = [
            (np.array([1, 2, 3], dtype=np.int32), np.array([4, 5, 6], dtype=np.int32), True),
            (np.array([1.0, 2.0, 3.0], dtype=np.float32), np.array([4.0, 5.0, 6.0], dtype=np.float32), True),
            (np.array([True, False, True], dtype=np.bool_), np.array([False, True, False], dtype=np.bool_), True),
            (np.array([1, 2, 3], dtype=np.int32), np.array([4.0, 5.0, 6.0], dtype=np.float32), False),
            (np.array([1, 2, 3], dtype=np.int32), np.array([True, False, True], dtype=np.bool_), False),
        ]
        for x, y, expected in cases:
            # subTest reports every failing dtype pair instead of stopping at the first one.
            with self.subTest(x_dtype=str(x.dtype), y_dtype=str(y.dtype)):
                # bool() keeps the original truthiness semantics of assertTrue/assertFalse.
                self.assertEqual(bool(check_dtype_comparable(x, y)), expected)

    def test_convert_str_to_float_when_valid_float(self):
        self.assertEqual(convert_str_to_float("123.45"), 123.45)

    def test_convert_str_to_float_when_valid_int(self):
        # Covers a plain integer literal, as the test name promises, plus its float form.
        self.assertEqual(convert_str_to_float("123"), 123.0)
        self.assertEqual(convert_str_to_float("123.0"), 123.0)

    def test_convert_str_to_float_when_valid_int_with_spaces(self):
        # Surrounding whitespace is tolerated for both integer and float literals.
        self.assertEqual(convert_str_to_float(" 123 "), 123.0)
        self.assertEqual(convert_str_to_float(" 123.0 "), 123.0)

    def test_convert_str_to_float_when_empty_string(self):
        """An empty string raises CompareException with code INVALID_DATA_ERROR."""
        with self.assertRaises(CompareException) as cm:
            convert_str_to_float('')
        self.assertEqual(cm.exception.code, CompareException.INVALID_DATA_ERROR)
@@ -0,0 +1,179 @@
1
+ {
2
+ "task": "statistics",
3
+ "level": "mix",
4
+ "dump_data_dir": null,
5
+ "data": {
6
+ "Tensor.__mul__.7.forward": {
7
+ "input_args": [
8
+ {
9
+ "type": "torch.Tensor",
10
+ "dtype": "torch.float16",
11
+ "shape": [
12
+ 2048,
13
+ 2,
14
+ 1,
15
+ 256
16
+ ],
17
+ "Max": 1.3955078125,
18
+ "Min": -1.443359375,
19
+ "Mean": -0.00013697147369384766,
20
+ "Norm": 318.5,
21
+ "requires_grad": true
22
+ },
23
+ {
24
+ "type": "torch.Tensor",
25
+ "dtype": "torch.float16",
26
+ "shape": [
27
+ 2048,
28
+ 1,
29
+ 1,
30
+ 256
31
+ ],
32
+ "Max": 1.0,
33
+ "Min": -1.0,
34
+ "Mean": 0.214599609375,
35
+ "Norm": 547.0,
36
+ "requires_grad": false
37
+ }
38
+ ],
39
+ "input_kwargs": {},
40
+ "output": [
41
+ {
42
+ "type": "torch.Tensor",
43
+ "dtype": "torch.float16",
44
+ "shape": [
45
+ 2048,
46
+ 2,
47
+ 1,
48
+ 256
49
+ ],
50
+ "Max": 1.3564453125,
51
+ "Min": -1.443359375,
52
+ "Mean": -0.0014209747314453125,
53
+ "Norm": 240.125,
54
+ "requires_grad": true
55
+ }
56
+ ]
57
+ },
58
+ "Torch.chunk.4.forward": {
59
+ "input_args": [
60
+ {
61
+ "type": "torch.Tensor",
62
+ "dtype": "torch.float16",
63
+ "shape": [
64
+ 2048,
65
+ 2,
66
+ 1,
67
+ 256
68
+ ],
69
+ "Max": 1.3955078125,
70
+ "Min": -1.443359375,
71
+ "Mean": -0.00013697147369384766,
72
+ "Norm": 318.5,
73
+ "requires_grad": true
74
+ },
75
+ {
76
+ "type": "int",
77
+ "value": 2
78
+ }
79
+ ],
80
+ "input_kwargs": {
81
+ "dim": {
82
+ "type": "int",
83
+ "value": -1
84
+ }
85
+ },
86
+ "output": [
87
+ {
88
+ "type": "torch.Tensor",
89
+ "dtype": "torch.float16",
90
+ "shape": [
91
+ 2048,
92
+ 2,
93
+ 1,
94
+ 128
95
+ ],
96
+ "Max": 1.3720703125,
97
+ "Min": -1.3759765625,
98
+ "Mean": 0.0015316009521484375,
99
+ "Norm": 226.25,
100
+ "requires_grad": true
101
+ },
102
+ {
103
+ "type": "torch.Tensor",
104
+ "dtype": "torch.float16",
105
+ "shape": [
106
+ 2048,
107
+ 2,
108
+ 1,
109
+ 128
110
+ ],
111
+ "Max": 1.3955078125,
112
+ "Min": -1.443359375,
113
+ "Mean": -0.0018053054809570312,
114
+ "Norm": 224.375,
115
+ "requires_grad": true
116
+ }
117
+ ]
118
+ },
119
+ "Torch.cat.8.forward": {
120
+ "input_args": [
121
+ [
122
+ {
123
+ "type": "torch.Tensor",
124
+ "dtype": "torch.float16",
125
+ "shape": [
126
+ 2048,
127
+ 2,
128
+ 1,
129
+ 128
130
+ ],
131
+ "Max": 1.443359375,
132
+ "Min": -1.3955078125,
133
+ "Mean": 0.0018053054809570312,
134
+ "Norm": 224.375,
135
+ "requires_grad": true
136
+ },
137
+ {
138
+ "type": "torch.Tensor",
139
+ "dtype": "torch.float16",
140
+ "shape": [
141
+ 2048,
142
+ 2,
143
+ 1,
144
+ 128
145
+ ],
146
+ "Max": 1.3720703125,
147
+ "Min": -1.3759765625,
148
+ "Mean": 0.0015316009521484375,
149
+ "Norm": 226.25,
150
+ "requires_grad": true
151
+ }
152
+ ]
153
+ ],
154
+ "input_kwargs": {
155
+ "dim": {
156
+ "type": "int",
157
+ "value": -1
158
+ }
159
+ },
160
+ "output": [
161
+ {
162
+ "type": "torch.Tensor",
163
+ "dtype": "torch.float16",
164
+ "shape": [
165
+ 2048,
166
+ 2,
167
+ 1,
168
+ 256
169
+ ],
170
+ "Max": 1.443359375,
171
+ "Min": -1.3955078125,
172
+ "Mean": 0.0016689300537109375,
173
+ "Norm": 318.5,
174
+ "requires_grad": true
175
+ }
176
+ ]
177
+ }
178
+ }
179
+ }
@@ -0,0 +1,63 @@
1
+ {
2
+ "Torch.chunk.4.forward": {
3
+ "input_args": [
4
+ {
5
+ "type": "torch.Tensor",
6
+ "dtype": "torch.float16",
7
+ "shape": [
8
+ 2048,
9
+ 2,
10
+ 1,
11
+ 256
12
+ ],
13
+ "Max": 1.3955078125,
14
+ "Min": -1.443359375,
15
+ "Mean": -0.00013697147369384766,
16
+ "Norm": 318.5,
17
+ "requires_grad": true
18
+ },
19
+ {
20
+ "type": "int",
21
+ "value": 2
22
+ }
23
+ ],
24
+ "input_kwargs": {
25
+ "dim": {
26
+ "type": "int",
27
+ "value": -1
28
+ }
29
+ },
30
+ "output": [
31
+ {
32
+ "type": "torch.Tensor",
33
+ "dtype": "torch.float16",
34
+ "shape": [
35
+ 2048,
36
+ 2,
37
+ 1,
38
+ 128
39
+ ],
40
+ "Max": 1.3720703125,
41
+ "Min": -1.3759765625,
42
+ "Mean": 0.0015316009521484375,
43
+ "Norm": 226.25,
44
+ "requires_grad": true
45
+ },
46
+ {
47
+ "type": "torch.Tensor",
48
+ "dtype": "torch.float16",
49
+ "shape": [
50
+ 2048,
51
+ 2,
52
+ 1,
53
+ 128
54
+ ],
55
+ "Max": 1.3955078125,
56
+ "Min": -1.443359375,
57
+ "Mean": -0.0018053054809570312,
58
+ "Norm": 224.375,
59
+ "requires_grad": true
60
+ }
61
+ ]
62
+ }
63
+ }
@@ -0,0 +1,99 @@
1
+ # coding=utf-8
2
+ import os
3
+ import unittest
4
+ import copy
5
+
6
+ from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import *
7
+ from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents
8
+
9
base_dir = os.path.dirname(os.path.realpath(__file__))
forward_file = os.path.join(base_dir, "forward.json")
forward_content = get_json_contents(forward_file)
# Fail at import with a clear message instead of letting every test die with a
# NameError on ``api_info_dict`` when the fixture contains no API entries.
if not forward_content:
    raise ValueError(f"No API entries found in fixture file: {forward_file}")
# The tests operate on a single API entry; if several are present the last one wins
# (the same entry a full pass over ``forward_content.items()`` would leave bound).
api_full_name, api_info_dict = list(forward_content.items())[-1]

# Reference extrema the generated tensors are compared against (tests allow a 0.001 difference).
# NOTE(review): these are 0.001 away from the fixture's "Max" (1.3955078125) and
# "Min" (-1.443359375) in forward.json; confirm whether the offset is intentional.
max_value = 1.3945078125
min_value = -1.444359375
18
+
19
+
20
class TestDataGenerateMethods(unittest.TestCase):
    """Tests for the input generators in ``run_ut.data_generate``, driven by forward.json."""

    def _assert_matches_fixture_tensor(self, tensor):
        """Assert *tensor* is float16, has extrema near the reference values, and the fixture shape."""
        tolerance = 0.001
        self.assertEqual(tensor.dtype, torch.float16)
        self.assertLessEqual(abs(tensor.max() - max_value), tolerance)
        self.assertLessEqual(abs(tensor.min() - min_value), tolerance)
        self.assertEqual(tensor.shape, torch.Size([2048, 2, 1, 256]))

    def test_gen_api_params(self):
        args_params, kwargs_params = gen_api_params(copy.deepcopy(api_info_dict), True, None, None)
        self.assertEqual(len(args_params), 2)
        self.assertEqual(args_params[1], 2)
        self._assert_matches_fixture_tensor(args_params[0])
        self.assertEqual(kwargs_params, {'dim': -1})

    def test_gen_args(self):
        generated_args = gen_args(api_info_dict.get('input_args'), "conv2d")
        self.assertEqual(len(generated_args), 2)
        self._assert_matches_fixture_tensor(generated_args[0])

    def test_gen_data(self):
        first_arg_info = api_info_dict.get('input_args')[0]
        tensor = gen_data(first_arg_info, "conv2d", True, None, None)
        self.assertEqual(tensor.requires_grad, True)
        self._assert_matches_fixture_tensor(tensor)

    def test_gen_kwargs(self):
        generated_kwargs = gen_kwargs(copy.deepcopy(api_info_dict), None)
        self.assertEqual(generated_kwargs, {'dim': -1})

    def test_gen_kwargs_2(self):
        kwargs_info = {"inplace": {"type": "bool", "value": "False"}}
        # The assertion below expects gen_torch_kwargs to replace each raw spec in kwargs_info
        # with its converted value.
        for name, spec in kwargs_info.items():
            gen_torch_kwargs(kwargs_info, name, spec)
        self.assertEqual(kwargs_info, {'inplace': False})

    def test_gen_random_tensor(self):
        tensor = gen_random_tensor(api_info_dict.get('input_args')[0], None)
        self.assertEqual(tensor.requires_grad, False)
        self._assert_matches_fixture_tensor(tensor)

    def test_gen_common_tensor(self):
        arg_info = api_info_dict.get('input_args')[0]
        low_info = [arg_info.get('Min'), arg_info.get('Min_origin')]
        high_info = [arg_info.get('Max'), arg_info.get('Max_origin')]
        tensor = gen_common_tensor(low_info, high_info, tuple(arg_info.get('shape')), arg_info.get('dtype'), None)
        self.assertEqual(tensor.requires_grad, False)
        self._assert_matches_fixture_tensor(tensor)

    def test_gen_bool_tensor(self):
        bool_info = {"type": "torch.Tensor", "dtype": "torch.bool", "shape": [1, 1, 160, 256], "Max": 1, "Min": 0,
                     "requires_grad": False}
        tensor = gen_bool_tensor(bool_info.get("Min"), bool_info.get("Max"), tuple(bool_info.get("shape")))
        self.assertEqual(tensor.shape, torch.Size([1, 1, 160, 256]))
        self.assertEqual(tensor.dtype, torch.bool)
@@ -0,0 +1,115 @@
1
+ import os
2
+ import glob
3
+ import unittest
4
+ import logging
5
+ from unittest.mock import patch, mock_open, MagicMock
6
+ import json
7
+ import signal
8
+ from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import split_json_file, signal_handler, run_parallel_ut, \
9
+ prepare_config, main, ParallelUTConfig
10
+
11
+
12
class TestMultiRunUT(unittest.TestCase):
    """Unit tests for the helpers in ``run_ut.multi_run_ut`` (split, parallel run, config, main)."""

    def setUp(self):
        # Path of a dump.json next to this test module; test_split_json_file
        # patches FileOpen, so its contents come from test_json_content.
        self.test_json_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "dump.json")
        self.test_data = {'data': {'key1': 'TRUE', 'key2': 'TRUE', 'key3': 'TRUE'}}
        self.test_json_content = json.dumps(self.test_data)
        # Payloads served to the two mocked open() calls in test_run_parallel_ut.
        self.forward_split_files_content = [
            {'key1': 'TRUE', 'key2': 'TRUE'},
            {'key3': 'TRUE', 'key4': 'TRUE'}
        ]

    @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.FileOpen')
    def test_split_json_file(self, mock_FileOpen):
        """split_json_file should yield num_splits files and report the item count of the 'data' section."""
        mock_FileOpen.return_value.__enter__.return_value = mock_open(read_data=self.test_json_content).return_value
        num_splits = 2
        split_files, total_items = split_json_file(self.test_json_file, num_splits, False)
        self.assertEqual(len(split_files), num_splits)
        self.assertEqual(total_items, len(self.test_data.get('data')))

    @patch('subprocess.Popen')
    @patch('os.path.exists', return_value=True)
    @patch('builtins.open', new_callable=mock_open)
    @patch('json.load', side_effect=lambda f: {'key1': 'TRUE', 'key2': 'TRUE'})
    def test_run_parallel_ut(self, mock_json_load, mock_file, mock_exists, mock_popen):
        """run_parallel_ut should launch subprocesses (Popen is mocked) and consult os.path.exists."""
        mock_process = MagicMock()
        # The fake child reports "still running" twice, then exits with code 1.
        mock_process.poll.side_effect = [None, None, 1]
        mock_process.stdout.readline.side_effect = ['[ERROR] Test Error Message\n', '']
        mock_popen.return_value = mock_process

        config = ParallelUTConfig(
            api_files=['test.json'],
            out_path='./',
            num_splits=2,
            save_error_data_flag=True,
            jit_compile_flag=False,
            device_id=[0, 1],
            result_csv_path='result.csv',
            total_items=2,
            real_data_path=None
        )

        # NOTE(review): only two open() results are queued; a third open()
        # call inside run_parallel_ut would raise StopIteration.
        mock_file.side_effect = [
            mock_open(read_data=json.dumps(self.forward_split_files_content[0])).return_value,
            mock_open(read_data=json.dumps(self.forward_split_files_content[1])).return_value
        ]

        run_parallel_ut(config)

        mock_popen.assert_called()
        mock_exists.assert_called()

    @patch('os.remove')
    @patch('os.path.realpath', side_effect=lambda x: x)
    @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.check_link')
    @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.check_file_suffix')
    @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.FileChecker')
    @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.split_json_file',
           return_value=(['forward_split1.json', 'forward_split2.json'], 2))
    def test_prepare_config(self, mock_split_json_file, mock_FileChecker, mock_check_file_suffix, mock_check_link,
                            mock_realpath, mock_remove):
        """prepare_config should carry CLI flags into the config and take total_items from split_json_file."""
        mock_FileChecker_instance = MagicMock()
        mock_FileChecker_instance.common_check.return_value = './'
        mock_FileChecker.return_value = mock_FileChecker_instance
        # A MagicMock stands in for the argparse namespace; any attribute not
        # set below is auto-created as another MagicMock.
        args = MagicMock()
        args.api_info = 'forward.json'
        args.out_path = './'
        args.num_splits = 2
        args.save_error_data = True
        args.jit_compile = False
        args.device_id = [0, 1]
        args.result_csv_path = None
        args.real_data_path = None

        config = prepare_config(args)

        self.assertEqual(config.num_splits, 2)
        self.assertTrue(config.save_error_data_flag)
        self.assertFalse(config.jit_compile_flag)
        self.assertEqual(config.device_id, [0, 1])
        self.assertEqual(config.total_items, 2)

    @patch('argparse.ArgumentParser.parse_args')
    @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.prepare_config')
    @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.run_parallel_ut')
    def test_main(self, mock_run_parallel_ut, mock_prepare_config, mock_parse_args):
        """main() should parse args, build a config and hand it to run_parallel_ut."""
        main()
        mock_parse_args.assert_called()
        mock_prepare_config.assert_called()
        mock_run_parallel_ut.assert_called()

    def tearDown(self):
        """Delete ``accuracy_checking_*`` entries from the current working directory.

        NOTE(review): this globs the process cwd (the tests use out_path='./'),
        so running the suite from a directory that holds real accuracy-checking
        results would delete them; pointing out_path at a temporary directory
        would be safer.
        """
        pattern = os.path.join(os.getcwd(), 'accuracy_checking_*')
        for path in glob.glob(pattern):
            try:
                os.remove(path)
            except OSError as err:
                # Best-effort cleanup: report the failure and keep going.
                logging.error("Failed to delete file %s: %s", path, err)
            else:
                logging.info("Deleted file: %s", path)
115
+