mindstudio-probe 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228) hide show
  1. mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
  2. mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
  3. mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
  4. mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
  5. mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
  6. mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
  7. msprobe/README.md +182 -0
  8. msprobe/__init__.py +0 -0
  9. msprobe/config/README.md +397 -0
  10. msprobe/config/config.json +28 -0
  11. msprobe/config/img/free_benchmark.png +0 -0
  12. msprobe/core/common/const.py +241 -0
  13. msprobe/core/common/exceptions.py +88 -0
  14. msprobe/core/common/file_check.py +265 -0
  15. msprobe/core/common/log.py +55 -0
  16. msprobe/core/common/utils.py +516 -0
  17. msprobe/core/common_config.py +58 -0
  18. msprobe/core/data_dump/data_collector.py +140 -0
  19. msprobe/core/data_dump/data_processor/base.py +245 -0
  20. msprobe/core/data_dump/data_processor/factory.py +61 -0
  21. msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
  22. msprobe/core/data_dump/json_writer.py +116 -0
  23. msprobe/core/data_dump/scope.py +178 -0
  24. msprobe/mindspore/__init__.py +1 -0
  25. msprobe/mindspore/debugger/__init__.py +0 -0
  26. msprobe/mindspore/debugger/debugger_config.py +51 -0
  27. msprobe/mindspore/debugger/precision_debugger.py +32 -0
  28. msprobe/mindspore/doc/dump.md +65 -0
  29. msprobe/mindspore/dump/__init__.py +0 -0
  30. msprobe/mindspore/dump/api_kbk_dump.py +55 -0
  31. msprobe/mindspore/dump/dump_tool_factory.py +38 -0
  32. msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
  33. msprobe/mindspore/ms_config.py +78 -0
  34. msprobe/mindspore/overflow_check/__init__.py +0 -0
  35. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
  36. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
  37. msprobe/mindspore/task_handler_factory.py +21 -0
  38. msprobe/msprobe.py +67 -0
  39. msprobe/pytorch/__init__.py +4 -0
  40. msprobe/pytorch/advisor/advisor.py +124 -0
  41. msprobe/pytorch/advisor/advisor_const.py +59 -0
  42. msprobe/pytorch/advisor/advisor_result.py +58 -0
  43. msprobe/pytorch/api_accuracy_checker/.keep +0 -0
  44. msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
  45. msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
  46. msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
  47. msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
  48. msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
  49. msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
  50. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
  51. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
  52. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
  53. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
  54. msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
  55. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
  58. msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
  59. msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
  60. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
  61. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
  62. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
  63. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
  64. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
  65. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
  66. msprobe/pytorch/common/__init__.py +2 -0
  67. msprobe/pytorch/common/compare_script.template +14 -0
  68. msprobe/pytorch/common/log.py +32 -0
  69. msprobe/pytorch/common/parse_json.py +37 -0
  70. msprobe/pytorch/common/utils.py +224 -0
  71. msprobe/pytorch/compare/acc_compare.py +1024 -0
  72. msprobe/pytorch/compare/distributed_compare.py +111 -0
  73. msprobe/pytorch/compare/highlight.py +100 -0
  74. msprobe/pytorch/compare/mapping.yaml +607 -0
  75. msprobe/pytorch/compare/match.py +36 -0
  76. msprobe/pytorch/compare/npy_compare.py +244 -0
  77. msprobe/pytorch/debugger/__init__.py +0 -0
  78. msprobe/pytorch/debugger/debugger_config.py +86 -0
  79. msprobe/pytorch/debugger/precision_debugger.py +95 -0
  80. msprobe/pytorch/doc/FAQ.md +193 -0
  81. msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
  82. msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +182 -0
  83. msprobe/pytorch/doc/dump.md +207 -0
  84. msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
  85. msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
  86. msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
  87. msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
  88. msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
  89. msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
  90. msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
  91. msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
  92. msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
  93. msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
  94. msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
  95. msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
  96. msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
  97. msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
  98. msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
  99. msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
  100. msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
  101. msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
  102. msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
  103. msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
  104. msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
  105. msprobe/pytorch/doc/img/cpu_info.png +0 -0
  106. msprobe/pytorch/doc/img/module_compare.png +0 -0
  107. msprobe/pytorch/doc/parse_tool.md +286 -0
  108. msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
  109. msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
  110. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
  111. msprobe/pytorch/doc/run_overflow_check.md +25 -0
  112. msprobe/pytorch/doc/在精度比对.md +90 -0
  113. msprobe/pytorch/free_benchmark/__init__.py +8 -0
  114. msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
  115. msprobe/pytorch/free_benchmark/common/constant.py +67 -0
  116. msprobe/pytorch/free_benchmark/common/counter.py +72 -0
  117. msprobe/pytorch/free_benchmark/common/enums.py +37 -0
  118. msprobe/pytorch/free_benchmark/common/params.py +129 -0
  119. msprobe/pytorch/free_benchmark/common/utils.py +98 -0
  120. msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
  121. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
  122. msprobe/pytorch/free_benchmark/main.py +102 -0
  123. msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
  124. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
  125. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
  126. msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
  127. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
  132. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
  133. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
  134. msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
  135. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
  136. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
  137. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
  138. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
  139. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
  140. msprobe/pytorch/functional/__init__.py +0 -0
  141. msprobe/pytorch/functional/data_processor.py +0 -0
  142. msprobe/pytorch/functional/dump_module.py +39 -0
  143. msprobe/pytorch/hook_module/__init__.py +1 -0
  144. msprobe/pytorch/hook_module/api_registry.py +161 -0
  145. msprobe/pytorch/hook_module/hook_module.py +109 -0
  146. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
  147. msprobe/pytorch/hook_module/utils.py +29 -0
  148. msprobe/pytorch/hook_module/wrap_aten.py +100 -0
  149. msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
  150. msprobe/pytorch/hook_module/wrap_functional.py +108 -0
  151. msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
  152. msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
  153. msprobe/pytorch/hook_module/wrap_torch.py +88 -0
  154. msprobe/pytorch/hook_module/wrap_vf.py +64 -0
  155. msprobe/pytorch/module_processer.py +98 -0
  156. msprobe/pytorch/online_dispatch/__init__.py +20 -0
  157. msprobe/pytorch/online_dispatch/compare.py +236 -0
  158. msprobe/pytorch/online_dispatch/dispatch.py +274 -0
  159. msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
  160. msprobe/pytorch/online_dispatch/single_compare.py +391 -0
  161. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
  162. msprobe/pytorch/online_dispatch/utils.py +187 -0
  163. msprobe/pytorch/parse.py +4 -0
  164. msprobe/pytorch/parse_tool/__init__.py +0 -0
  165. msprobe/pytorch/parse_tool/cli.py +32 -0
  166. msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
  167. msprobe/pytorch/parse_tool/lib/compare.py +259 -0
  168. msprobe/pytorch/parse_tool/lib/config.py +51 -0
  169. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
  170. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
  171. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
  172. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
  173. msprobe/pytorch/parse_tool/lib/utils.py +367 -0
  174. msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
  175. msprobe/pytorch/pt_config.py +93 -0
  176. msprobe/pytorch/service.py +167 -0
  177. msprobe/test/core_ut/common/test_utils.py +345 -0
  178. msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
  179. msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
  180. msprobe/test/core_ut/data_dump/test_scope.py +151 -0
  181. msprobe/test/core_ut/test_common_config.py +152 -0
  182. msprobe/test/core_ut/test_file_check.py +218 -0
  183. msprobe/test/core_ut/test_log.py +109 -0
  184. msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
  185. msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
  186. msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
  187. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
  188. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
  189. msprobe/test/mindspore_ut/test_ms_config.py +69 -0
  190. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
  191. msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
  192. msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
  193. msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
  194. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
  195. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
  196. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
  197. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
  198. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
  199. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
  200. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
  201. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
  202. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
  203. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
  204. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
  205. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
  206. msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
  207. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
  208. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
  209. msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
  210. msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
  211. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
  212. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
  213. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
  214. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
  215. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
  216. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
  217. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
  218. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
  219. msprobe/test/pytorch_ut/test_pt_config.py +69 -0
  220. msprobe/test/pytorch_ut/test_service.py +59 -0
  221. msprobe/test/resources/advisor.txt +3 -0
  222. msprobe/test/resources/compare_result_20230703104808.csv +9 -0
  223. msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
  224. msprobe/test/resources/config.yaml +3 -0
  225. msprobe/test/resources/npu_test.pkl +8 -0
  226. msprobe/test/run_test.sh +30 -0
  227. msprobe/test/run_ut.py +58 -0
  228. msprobe/test/test_module_processer.py +64 -0
@@ -0,0 +1,72 @@
1
+ # coding=utf-8
2
+ import os
3
+ import copy
4
+ import unittest
5
+ import torch
6
+ from unittest.mock import patch, DEFAULT
7
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import *
8
+ from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents
9
+
10
base_dir = os.path.dirname(os.path.realpath(__file__))
forward_file = os.path.join(base_dir, "forward.json")
forward_content = get_json_contents(forward_file)
# The tests below exercise a single API record. The previous loop only rebound
# the loop variables to themselves, so what it effectively did was keep the LAST
# entry of forward.json; do that explicitly. An empty file used to surface as a
# confusing NameError inside every test, so fail early with a clear message.
if not forward_content:
    raise ValueError(f"no API records found in {forward_file}")
api_full_name, api_info_dict = list(forward_content.items())[-1]
16
+
17
+
18
class TestRunUtMethods(unittest.TestCase):
    """Exercise the run_ut helpers against the API record loaded from forward.json."""

    @staticmethod
    def _cpu_inputs_from_record():
        """Build CPU-side arguments for the loaded API record.

        The record is deep-copied first so that whatever get_api_info does to it
        cannot leak into other tests.

        Returns:
            Tuple ``(api_type, api_name, cpu_args, cpu_kwargs)``.
        """
        record = copy.deepcopy(api_info_dict)
        api_type, api_name, _, _ = api_full_name.split(".")
        args, kwargs, _ = get_api_info(record, api_name, None)
        cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, True, '')
        return api_type, api_name, cpu_args, cpu_kwargs

    def test_exec_api(self):
        api_type, api_name, cpu_args, cpu_kwargs = self._cpu_inputs_from_record()
        head = exec_api(api_type, api_name, cpu_args, cpu_kwargs)[0]
        self.assertEqual(head.dtype, torch.float32)
        self.assertTrue(head.requires_grad)
        self.assertEqual(head.shape, torch.Size([2048, 2, 1, 128]))

    def test_generate_device_params(self):
        fake_tensor = torch.rand([2, 2560, 24, 24], dtype=torch.float32, requires_grad=True)
        stubbed = ('to', 'clone', 'detach', 'requires_grad_', 'type_as', 'retain_grad')

        with patch.multiple('torch.Tensor', **dict.fromkeys(stubbed, DEFAULT)) as mocks:
            # Every stubbed Tensor method hands back the same tensor, except
            # retain_grad, which is stubbed to return None.
            for method in stubbed:
                mocks[method].return_value = None if method == 'retain_grad' else fake_tensor

            device_args, device_kwargs = generate_device_params([fake_tensor], {'inplace': False}, True, '')
            self.assertEqual(len(device_args), 1)
            moved = device_args[0]
            self.assertEqual(moved.dtype, torch.float32)
            self.assertTrue(moved.requires_grad)
            self.assertEqual(moved.shape, torch.Size([2, 2560, 24, 24]))
            self.assertEqual(device_kwargs, {'inplace': False})

    def test_generate_cpu_params(self):
        _, _, cpu_args, cpu_kwargs = self._cpu_inputs_from_record()
        self.assertEqual(len(cpu_args), 2)
        first_arg = cpu_args[0]
        self.assertEqual(first_arg.dtype, torch.float32)
        self.assertTrue(first_arg.requires_grad)
        self.assertEqual(first_arg.shape, torch.Size([2048, 2, 1, 256]))
        self.assertEqual(cpu_kwargs, {'dim': -1})

    def test_UtDataInfo(self):
        data_info = UtDataInfo(*([None] * 7))
        for attr in ('bench_grad', 'device_grad', 'device_output',
                     'bench_output', 'grad_in', 'in_fwd_data_list'):
            self.assertIsNone(getattr(data_info, attr))
@@ -0,0 +1,17 @@
1
+ # coding=utf-8
2
+ import unittest
3
+ from msprobe.pytorch.compare.acc_compare import rename_api
4
+
5
class TestUtilsMethods(unittest.TestCase):
    """Tests for the dump-name helpers in msprobe.pytorch.compare.acc_compare."""

    def test_rename_api(self):
        # Per these cases, rename_api drops the call index and the stage segment
        # from a full dump name.
        cases = (
            ("Distributed.broadcast.0.forward.input.0", "forward", "Distributed.broadcast.input.0"),
            ("Torch.sum.0.backward.output.0", "backward", "Torch.sum.output.0"),
        )
        for full_name, stage, expected in cases:
            self.assertEqual(rename_api(full_name, stage), expected)
17
+
@@ -0,0 +1,105 @@
1
+ from unittest import TestCase
2
+
3
+ import torch
4
+ from msprobe.core.common.const import Const
5
+ from msprobe.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode
6
+ from msprobe.pytorch.free_benchmark.common.params import data_pre_deal
7
+ from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
8
+
9
+
10
class TestPerturbedLayer(TestCase):
    """Checks for the NPU perturbation layers produced by LayerFactory."""

    # Applying the improve-precision perturbation to an op whose output dtype
    # matches its input dtype also raises the precision of the output.
    def test_improve_precision_layer_handle_with_out_dtype_changing(self):
        api_name = "Torch.mul.0.forward"
        lhs = torch.randn(2, 3, dtype=torch.float16)
        rhs = torch.randn(2, 3, dtype=torch.float16)
        product = torch.mul(lhs, rhs)

        params = data_pre_deal(api_name, torch.mul, (lhs, rhs), {})
        params.fuzz_stage = Const.FORWARD
        params.original_result = product

        layer = LayerFactory.create(api_name, DeviceType.NPU, PerturbationMode.IMPROVE_PRECISION)
        layer.handle(params)

        self.assertEqual(params.original_result.dtype, torch.float16)
        self.assertEqual(layer.perturbed_value, torch.float32)
        self.assertEqual(params.perturbed_result.dtype, torch.float32)

    # For iterable inputs, improve_tensor_precision walks the elements (including
    # nested dict values) and upcasts those whose dtype it supports.
    def test_improve_precision_layer_with_iterable_inputs(self):
        api_name = "iterable.0.forward"
        bf16_tensor = torch.randn(2, 3, dtype=torch.bfloat16)
        fp16_tensor = torch.randn(2, 3, dtype=torch.float16)
        fp32_tensor = torch.randn(2, 3, dtype=torch.float32)
        fp64_tensor = torch.randn(2, 3, dtype=torch.float64)
        int_tensor = torch.randn(2, 3, dtype=torch.float64).to(torch.int32)
        inputs = [bf16_tensor, fp16_tensor, {"c": fp32_tensor, "d": fp64_tensor}, int_tensor]

        layer = LayerFactory.create(api_name, DeviceType.NPU, PerturbationMode.IMPROVE_PRECISION)
        upcast = layer.improve_tensor_precision(inputs)

        # bf16/fp16 become fp32; fp32, fp64 and integer tensors keep their dtype.
        self.assertEqual(upcast[0].dtype, torch.float32)
        self.assertEqual(upcast[1].dtype, torch.float32)
        self.assertEqual(upcast[2]["c"].dtype, torch.float32)
        self.assertEqual(upcast[2]["d"].dtype, torch.float64)
        self.assertEqual(upcast[3].dtype, torch.int32)

    # The no_change perturbation leaves the input values untouched.
    def test_no_change_layer(self):
        api_name = "nochange.0.forward"
        values = torch.as_tensor([1e-9, 1e-2], dtype=torch.float32)
        layer = LayerFactory.create(api_name, DeviceType.NPU, PerturbationMode.NO_CHANGE)
        unchanged = layer.no_change(values)
        self.assertEqual(unchanged[0], 1e-9)
        self.assertEqual(unchanged[1], 1e-2)

    # For 1-D and 2-D tensors, change_value swaps the first and last values.
    def test_change_value_layer(self):
        api_name = "change.0.forward"
        vector = torch.as_tensor([1e-9, 1e-7, 1e-2], dtype=torch.float32)
        matrix = torch.as_tensor(
            [[1e-9, 1e-7, 1e-2], [1e-9, 1e-2, 1e-7]], dtype=torch.float32
        )
        layer = LayerFactory.create(api_name, DeviceType.NPU, PerturbationMode.CHANGE_VALUE)
        swapped_vector = layer.change_value(vector)
        # Reset the flag so the second call perturbs as well.
        layer.is_added = False
        swapped_matrix = layer.change_value(matrix)

        self.assertEqual(swapped_vector[0], 1e-2)
        self.assertEqual(swapped_vector[2], 1e-9)
        self.assertEqual(swapped_matrix[0][0], 1e-7)
        self.assertEqual(swapped_matrix[-1][-1], 1e-9)

    # bit_noise flips the last bit of elements above a tiny threshold; the tiny
    # element is left as-is.
    def test_bit_noise_layer(self):
        api_name = "bitnoise.0.forward"
        values = torch.as_tensor([4096.00048828125, 16777216, 1e-38], dtype=torch.float32)
        layer = LayerFactory.create(api_name, DeviceType.NPU, PerturbationMode.BIT_NOISE)
        flipped = layer.add_bit_noise(values)
        self.assertEqual(flipped[0], 4096.0)
        self.assertEqual(flipped[1], 16777218)
        self.assertEqual(flipped[2], 1e-38)

    # add_noise adds a small value to elements above a tiny threshold.
    def test_add_noise_layer(self):
        api_name = "addnoise.0.forward"
        values = torch.as_tensor([1e-1, 1e-2], dtype=torch.bfloat16)
        layer = LayerFactory.create(api_name, DeviceType.NPU, PerturbationMode.ADD_NOISE)
        noisy = layer.add_noise(values)
        self.assertEqual(noisy[0], 1e-1 + 1e-4)
        self.assertEqual(noisy[1], 1e-2)
@@ -0,0 +1,121 @@
1
+ from abc import ABC
2
+ from unittest import TestCase
3
+
4
+ import torch
5
+ from msprobe.core.common.const import Const
6
+ from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig, ThresholdConfig
7
+ from msprobe.pytorch.free_benchmark.common.counter import preheat_counter
8
+ from msprobe.pytorch.free_benchmark.common.enums import (
9
+ DeviceType,
10
+ FuzzLevel,
11
+ HandlerType,
12
+ PerturbationMode,
13
+ )
14
+ from msprobe.pytorch.free_benchmark.common.params import DataParams, make_handler_params
15
+ from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import (
16
+ FuzzHandlerFactory,
17
+ )
18
+
19
+
20
class Config(ABC):
    """Minimal free-benchmark configuration for the result-handler tests.

    Args:
        handler_type: HandlerType the handler factory should build.
        preheat_config: dict keyed by PreheatConfig fields.
    """

    def __init__(self, handler_type, preheat_config):
        # Settings shared by every handler test.
        self.fuzz_stage = Const.FORWARD
        self.fuzz_device = DeviceType.NPU
        self.fuzz_level = FuzzLevel.BASE_LEVEL
        self.pert_mode = PerturbationMode.IMPROVE_PRECISION
        # Settings chosen per test.
        self.handler_type = handler_type
        self.preheat_config = preheat_config
31
+
32
+
33
class TestFuzzHandler(TestCase):
    """Tests for the CHECK, FIX and preheat behaviour of the free-benchmark result handlers."""

    def setUp(self) -> None:
        base_inputs = [
            torch.as_tensor([3.01, 3.02], dtype=torch.float16),
            torch.as_tensor([0.02, 0.02], dtype=torch.float16),
        ]
        # Scale each input by 1.0021, just past the 1.002 error threshold, to
        # simulate a discrepancy on the second execution.
        scaled_inputs = [(t * 1.0021).to(torch.float32).to("cpu") for t in base_inputs]
        # DataParams describing a call whose perturbed output diverges from the original.
        self.data_params = DataParams(
            args=base_inputs,
            kwargs={},
            original_result=torch.add(*base_inputs),
            perturbed_result=torch.add(*scaled_inputs),
            origin_func=torch.add,
        )
        self.api_name = "add.0.forward"
        self.step = 0

    def test_result_handler_check(self):
        # The CHECK handler produces an UnequalRow when outputs before and after
        # perturbation disagree.
        for _ in range(2):
            check_config = Config(HandlerType.CHECK, {PreheatConfig.IF_PREHEAT: False})
            handler = FuzzHandlerFactory.create(
                make_handler_params(self.api_name, check_config, self.step)
            )
            handler.handle(self.data_params)
            self.assertEqual(len(handler.get_unequal_rows()), 1)

    def test_result_handler_fix(self):
        # The FIX handler replaces the original output with the perturbed one:
        # dtype and device follow the original, values follow the perturbed output.
        fix_config = Config(HandlerType.FIX, {PreheatConfig.IF_PREHEAT: False})
        handler = FuzzHandlerFactory.create(
            make_handler_params(self.api_name, fix_config, self.step)
        )
        fixed = handler.handle(self.data_params)
        self.assertEqual(fixed.dtype, torch.float16)
        self.assertEqual(fixed.device, self.data_params.original_result.device)
        expected = self.data_params.perturbed_result.to(torch.float16)
        for idx in range(2):
            self.assertAlmostEqual(fixed[idx], expected[idx])

    def test_result_handler_preheat(self):
        # With preheat enabled, the threshold after the preheat phase is adjusted
        # based on the CPU results.
        preheat_config = Config(
            HandlerType.CHECK,
            {
                PreheatConfig.IF_PREHEAT: True,
                PreheatConfig.PREHEAT_STEP: 4,
                PreheatConfig.MAX_SAMPLE: 3,
            },
        )

        def run_step(step):
            """Drive three handler calls for the given step."""
            for _ in range(3):
                handler = FuzzHandlerFactory.create(
                    make_handler_params(self.api_name, preheat_config, step)
                )
                handler.handle(self.data_params)

        run_step(0)
        # preheat_counter shows whether preheat ran correctly; the first step
        # records how many times the API was executed.
        self.assertEqual(preheat_counter.get_one_step_used_api("add"), 3)
        for step in range(1, 4):
            run_step(step)
            # Call time counts the API calls made in the current step.
            self.assertEqual(preheat_counter.get_api_called_time("add"), 3)
            # With at most 3 samples over these 3 steps, each step samples once.
            self.assertEqual(preheat_counter.get_api_sample_time("add"), 1)
            # During preheat the API threshold lies between the two threshold hyper-parameters.
            api_threshold = preheat_counter.get_api_thd("add", "torch.float16")
            self.assertLessEqual(api_threshold, ThresholdConfig.PREHEAT_INITIAL_THD)
            self.assertGreaterEqual(api_threshold, ThresholdConfig.DTYPE_PER_THD[torch.float16])
@@ -0,0 +1,101 @@
1
+ import functools
2
+ from abc import ABC
3
+ from unittest import TestCase
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from msprobe.core.common.const import Const
8
+ from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck
9
+ from msprobe.pytorch.free_benchmark.common.constant import CommonField, PreheatConfig
10
+ from msprobe.pytorch.free_benchmark.common.enums import (
11
+ DeviceType,
12
+ FuzzLevel,
13
+ HandlerType,
14
+ PerturbationMode,
15
+ )
16
+
17
+
18
class Config(ABC):
    """Minimal free-benchmark configuration for the interface tests.

    Only the fuzz stage and handler type vary. The device is NPU, the level is
    base, the perturbation is improve-precision, and preheat is switched off.
    """

    def __init__(self, fuzz_stage, handler_type):
        self.fuzz_stage, self.handler_type = fuzz_stage, handler_type
        self.fuzz_device = DeviceType.NPU
        self.fuzz_level = FuzzLevel.BASE_LEVEL
        self.pert_mode = PerturbationMode.IMPROVE_PRECISION
        self.preheat_config = {PreheatConfig.IF_PREHEAT: False}
30
+
31
+
32
class WrapMul(nn.Module):
    """``nn.Module`` wrapper around ``torch.mul`` so the checker can be driven with a module.

    Args:
        op_name: label kept on the instance (the tests pass the API name).
    """

    def __init__(self, op_name) -> None:
        super().__init__()
        self.op_name = op_name

    def forward(self, *args, **kwargs):
        """Return ``torch.mul`` applied to the given arguments unchanged."""
        multiply = torch.mul
        return multiply(*args, **kwargs)
43
+
44
+
45
class UnequalDataProcessor(ABC):
    """Test double handed to the checker as its data processor.

    Every value passed to ``update_unequal_rows`` is kept, in call order, in
    ``self.unequal_rows``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.unequal_rows = []

    def update_unequal_rows(self, unequal_rows):
        """Store the argument as one entry; a list argument is NOT flattened."""
        self.unequal_rows += [unequal_rows]
56
+
57
+
58
class TestInterface(TestCase):
    """End-to-end checks of FreeBenchmarkCheck's forward and backward entry points."""

    def setUp(self):
        self.api_name = "Torch.mul.0"

    def testForwardFix(self):
        # Forward interface: with FIX enabled in the forward hook, the fixed
        # result is returned as the hook's output.
        checker = FreeBenchmarkCheck(Config(Const.FORWARD, HandlerType.FIX))
        # Run the operator's forward pass on half-precision inputs.
        lhs = torch.randn(2, 3).to(torch.float16)
        rhs = torch.randn(2, 3).to(torch.float16)
        module = WrapMul(self.api_name)
        original_out = module(lhs, rhs)
        # Simulate the forward hook invoking the benchmark-free forward check.
        fixed, _ = checker.forward(
            self.api_name,
            module,
            args=(lhs, rhs),
            kwargs={},
            output=original_out,
        )
        self.assertEqual(fixed.dtype, torch.float32)

    def testBackwardCheck(self):
        # Backward interface: inputs are stashed at pre-forward and compared
        # after backward.
        checker = FreeBenchmarkCheck(Config(Const.BACKWARD, HandlerType.CHECK))
        processor = UnequalDataProcessor()
        # Inputs and the output gradient.
        lhs = torch.tensor([2, 3], dtype=torch.float16, requires_grad=True)
        rhs = torch.tensor([2, 3], dtype=torch.float16, requires_grad=True)
        grad_output = torch.tensor([1, 1], dtype=torch.float16)
        backward_name = Const.SEP.join([self.api_name, Const.BACKWARD])
        # pre_forward creates the grad saver instance on the module.
        module = WrapMul(self.api_name)
        checker.pre_forward(backward_name, module, processor, (lhs, rhs), {})
        # Run forward and backward; the backward pass yields the perturbed grad_input.
        out = module(lhs, rhs)
        checker.backward(backward_name, module, grad_output)
        out.backward(torch.ones_like(out))
        # The module must carry a grad saver whose perturbed grad_input is correct.
        self.assertTrue(hasattr(module, CommonField.GRADSAVER))
        saver = getattr(module, CommonField.GRADSAVER)
        self.assertEqual(saver.perturbed_grad_input[0][0], 2)
@@ -0,0 +1,15 @@
1
+ import unittest
2
+
3
+ import torch.nn as nn
4
+ from msprobe.pytorch import PrecisionDebugger
5
+ from msprobe.pytorch.functional.dump_module import module_dump, module_count
6
+
7
+
8
class TestDumpModule(unittest.TestCase):
    """Smoke test: module_dump registers the given module name in module_count."""

    def setUp(self):
        self.module = nn.Linear(in_features=8, out_features=4)

    def test_module_dump(self):
        # A PrecisionDebugger (dump_path "./dump") is constructed before dumping;
        # presumably module_dump needs an active debugger — confirm.
        PrecisionDebugger(dump_path="./dump")
        module_dump(self.module, "TestModule")
        self.assertIn("TestModule", module_count)
@@ -0,0 +1,130 @@
1
+ import unittest
2
+ from msprobe.pytorch.hook_module.api_registry import ApiRegistry, torch_version_above_2, is_gpu
3
+
4
+
5
class TestApiRegistry(unittest.TestCase):
    """Tests for ApiRegistry: storing/patching API attributes and hook setup."""

    def test_store_ori_attr(self):
        class A():
            a1 = 1

        class B():
            a = A()
            b1 = 1
            b2 = 2

        api_list = ["a.a1", "b1", "b2"]
        expect_output = {"a.a1": 1, "b1": 1, "b2": 2}
        actual_output = dict()
        # Dotted names such as "a.a1" are expected to be read through B.a.
        ApiRegistry.store_ori_attr(B, api_list, actual_output)
        self.assertEqual(actual_output, expect_output)

    def test_set_api_attr(self):
        class A():
            a1 = 1

        class B():
            a = A().__class__
            b1 = 1

        attr_dict = {"a.a2": 2, "b2": 2, "b3": 3}
        ApiRegistry.set_api_attr(B, attr_dict)

        for k, v in attr_dict.items():
            if '.' in k:
                sub_module_name, sub_op = k.rsplit('.', 1)
                sub_module = getattr(B, sub_module_name, None)
                self.assertEqual(getattr(sub_module, sub_op), v)
            else:
                self.assertEqual(getattr(B, k), v)

    @staticmethod
    def _make_registry():
        """Return an ApiRegistry whose every ``*_hook_attr`` is {"b2": 2, "b3": 3}."""
        attr_dict = {"b2": 2, "b3": 3}
        reg = ApiRegistry()
        reg.tensor_hook_attr = attr_dict
        reg.torch_hook_attr = attr_dict
        reg.functional_hook_attr = attr_dict
        reg.distributed_hook_attr = attr_dict
        reg.npu_distributed_hook_attr = attr_dict
        reg.aten_hook_attr = attr_dict
        reg.vf_hook_attr = attr_dict
        reg.torch_npu_hook_attr = attr_dict
        return reg

    def _assert_b2_on_torch_apis(self):
        """Assert ``b2 == 2`` on each torch API namespace under test."""
        import torch
        import torch.distributed as dist

        self.assertEqual(torch.Tensor.b2, 2)
        self.assertEqual(torch.b2, 2)
        self.assertEqual(torch.nn.functional.b2, 2)
        self.assertEqual(dist.b2, 2)
        self.assertEqual(dist.distributed_c10d.b2, 2)
        if torch_version_above_2:
            self.assertEqual(torch.ops.aten.b2, 2)
        self.assertEqual(torch._VF.b2, 2)
        # torch_npu namespaces are not checked: torch_npu is not installed in CI.

    def test_api_modularity(self):
        reg = self._make_registry()
        reg.api_modularity()
        self._assert_b2_on_torch_apis()

    def test_api_originality(self):
        reg = self._make_registry()
        # NOTE(review): only the *_hook_attr dicts are seeded here. If
        # api_originality restores from separate original-attribute state,
        # these assertions pass only because test_api_modularity ran first and
        # left ``b2`` on the torch modules — confirm and seed the originals
        # explicitly.
        reg.api_originality()
        self._assert_b2_on_torch_apis()

    def test_initialize_hook(self):
        from msprobe.pytorch.hook_module.api_registry import torch_without_guard_version

        def hook_test():
            pass

        reg = ApiRegistry()
        reg.initialize_hook(hook_test)
        # The previous check, assertFalse([] == attr), could never fail because
        # the hook attrs are dicts and [] never equals a dict. Assert that each
        # one was actually populated instead.
        self.assertTrue(reg.tensor_hook_attr)
        self.assertTrue(reg.torch_hook_attr)
        self.assertTrue(reg.functional_hook_attr)
        self.assertTrue(reg.distributed_hook_attr)
        if not is_gpu and not torch_without_guard_version:
            # NOTE(review): assumed to be filled only when torch_npu distributed
            # APIs are available — confirm against ApiRegistry.initialize_hook.
            self.assertTrue(reg.npu_distributed_hook_attr)
        if torch_version_above_2:
            self.assertTrue(reg.aten_hook_attr)
        if not is_gpu:
            self.assertTrue(reg.torch_npu_hook_attr)
@@ -0,0 +1,42 @@
1
+ import unittest
2
+ from unittest.mock import patch, Mock
3
+
4
+ from msprobe.pytorch.hook_module.hook_module import HOOKModule
5
+
6
class TestHookModule(unittest.TestCase):
    """Tests for HOOKModule.__call__ dispatch and forward-hook result handling."""

    def _set_prefix_op_name(self, value):
        """Set ``HOOKModule.prefix_op_name_`` for this test only.

        The attribute is set on the class itself. Without cleanup it would leak
        into every HOOKModule created by later tests, so the previous state is
        restored (or the attribute removed) when the test finishes.
        """
        had_own_attr = "prefix_op_name_" in vars(HOOKModule)
        previous = vars(HOOKModule).get("prefix_op_name_")
        HOOKModule.prefix_op_name_ = value

        def restore():
            if had_own_attr:
                HOOKModule.prefix_op_name_ = previous
            else:
                del HOOKModule.prefix_op_name_

        self.addCleanup(restore)

    def test_call_1(self):
        # These hooks are not expected to run: _call_func is mocked below, and
        # the call's result is expected to be exactly the mock's return value.
        def forward_pre_hook():
            return "result_input", "result_kwargs"

        def forward_hook():
            return 2

        def backward_hook():
            pass

        def hook(prefix):
            return forward_pre_hook, forward_hook, backward_hook

        self._set_prefix_op_name("123")
        test = HOOKModule(hook)
        test._call_func = Mock(return_value=1)
        result = test()
        self.assertEqual(result, 1)

    def test_call_2(self):
        def forward_pre_hook(module, args, kwargs):
            return args, kwargs

        def forward_hook(module, args, kwargs, result):
            # Return the inputs instead of the result. The test expects the
            # call to yield the positional-args tuple, not the mocked forward
            # value.
            return args

        def backward_hook():
            pass

        def hook(prefix):
            return forward_pre_hook, forward_hook, backward_hook

        self._set_prefix_op_name("123")
        value = 2
        test = HOOKModule(hook)
        test.forward = Mock(return_value=1)
        result = test(value)
        self.assertEqual(result, (value,))
@@ -0,0 +1,65 @@
1
+ import unittest
2
+ import torch
3
+ from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate, AtenOPPacketTemplate
4
+
5
+
6
def hook(name):
    """Build a stub (forward_pre_hook, forward_hook, backward_hook) triple.

    ``name`` matches the hook-builder signature but is not used. The pre-hook
    passes its inputs through unchanged. The forward hook returns the constant
    2 regardless of the real result. The backward hook does nothing.
    """
    def forward_pre_hook(nope, input, kwargs):
        return (input, kwargs)

    def forward_hook(nope, input, kwargs, result):
        fixed_output = 2
        return fixed_output

    def backward_hook():
        return None

    hooks = (forward_pre_hook, forward_hook, backward_hook)
    return hooks
15
+
16
+
17
+
18
def _torch_version_at_most_2_0():
    """Return True when the installed torch release is 2.0 or older.

    Numeric components are compared as integers, so "10.0" is treated as newer
    than "2.0" (a plain string comparison gets this wrong). "2.0.1" and "2.0.0"
    still count as newer than "2.0", exactly as the previous string check did.
    """
    import re

    release = torch.__version__.split("+")[0]
    numbers = tuple(int(part) for part in re.findall(r"\d+", release))
    return numbers <= (2, 0)


@unittest.skipIf(_torch_version_at_most_2_0(),
                 "AtenOPPacketTemplate tests require torch newer than 2.0")
class TestWrapAten(unittest.TestCase):
    """Tests for the aten op-packet wrapper built with the stub ``hook``."""

    def setUp(self):
        self.aten_op = AtenOPPacketTemplate(torch.ops.aten.convolution, hook)

    @staticmethod
    def _conv_args():
        """Return fresh positional args for aten::convolution.

        Order: input, weight, bias, stride, padding, dilation, transposed,
        output_padding, groups.
        """
        image = torch.randn(4, 3, 24, 24)
        kernel = torch.randn(10, 3, 3, 3)
        return (image, kernel, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)

    def test_atenop_attribute(self):
        self.assertEqual(self.aten_op.default.op, torch.ops.aten.convolution.default)
        self.assertEqual(self.aten_op.out.op, torch.ops.aten.convolution.out)

    def test_atenop_forward(self):
        # The stub forward hook returns 2 in place of the real convolution output.
        self.assertEqual(self.aten_op(*self._conv_args()), 2)

    def test_atenop_overload_forward(self):
        # Call an explicit overload rather than the packet, so the overload
        # wrapper's hook path is exercised too.
        self.assertEqual(self.aten_op.default(*self._conv_args()), 2)

    def test_atenop_nonattr(self):
        self.assertRaises(AttributeError, getattr, self.aten_op, "foo")

    def test_atenop_overloads(self):
        self.assertEqual(self.aten_op.overloads(), self.aten_op.opPacket.overloads())
62
+
63
+
64
+
65
+