mindstudio-probe 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
- mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
- mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
- mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
- mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
- mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
- msprobe/README.md +182 -0
- msprobe/__init__.py +0 -0
- msprobe/config/README.md +397 -0
- msprobe/config/config.json +28 -0
- msprobe/config/img/free_benchmark.png +0 -0
- msprobe/core/common/const.py +241 -0
- msprobe/core/common/exceptions.py +88 -0
- msprobe/core/common/file_check.py +265 -0
- msprobe/core/common/log.py +55 -0
- msprobe/core/common/utils.py +516 -0
- msprobe/core/common_config.py +58 -0
- msprobe/core/data_dump/data_collector.py +140 -0
- msprobe/core/data_dump/data_processor/base.py +245 -0
- msprobe/core/data_dump/data_processor/factory.py +61 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
- msprobe/core/data_dump/json_writer.py +116 -0
- msprobe/core/data_dump/scope.py +178 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/debugger/__init__.py +0 -0
- msprobe/mindspore/debugger/debugger_config.py +51 -0
- msprobe/mindspore/debugger/precision_debugger.py +32 -0
- msprobe/mindspore/doc/dump.md +65 -0
- msprobe/mindspore/dump/__init__.py +0 -0
- msprobe/mindspore/dump/api_kbk_dump.py +55 -0
- msprobe/mindspore/dump/dump_tool_factory.py +38 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
- msprobe/mindspore/ms_config.py +78 -0
- msprobe/mindspore/overflow_check/__init__.py +0 -0
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
- msprobe/mindspore/task_handler_factory.py +21 -0
- msprobe/msprobe.py +67 -0
- msprobe/pytorch/__init__.py +4 -0
- msprobe/pytorch/advisor/advisor.py +124 -0
- msprobe/pytorch/advisor/advisor_const.py +59 -0
- msprobe/pytorch/advisor/advisor_result.py +58 -0
- msprobe/pytorch/api_accuracy_checker/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
- msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
- msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
- msprobe/pytorch/common/__init__.py +2 -0
- msprobe/pytorch/common/compare_script.template +14 -0
- msprobe/pytorch/common/log.py +32 -0
- msprobe/pytorch/common/parse_json.py +37 -0
- msprobe/pytorch/common/utils.py +224 -0
- msprobe/pytorch/compare/acc_compare.py +1024 -0
- msprobe/pytorch/compare/distributed_compare.py +111 -0
- msprobe/pytorch/compare/highlight.py +100 -0
- msprobe/pytorch/compare/mapping.yaml +607 -0
- msprobe/pytorch/compare/match.py +36 -0
- msprobe/pytorch/compare/npy_compare.py +244 -0
- msprobe/pytorch/debugger/__init__.py +0 -0
- msprobe/pytorch/debugger/debugger_config.py +86 -0
- msprobe/pytorch/debugger/precision_debugger.py +95 -0
- msprobe/pytorch/doc/FAQ.md +193 -0
- msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
- msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +182 -0
- msprobe/pytorch/doc/dump.md +207 -0
- msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
- msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
- msprobe/pytorch/doc/img/cpu_info.png +0 -0
- msprobe/pytorch/doc/img/module_compare.png +0 -0
- msprobe/pytorch/doc/parse_tool.md +286 -0
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
- msprobe/pytorch/doc/run_overflow_check.md +25 -0
- msprobe/pytorch/doc/在线精度比对.md +90 -0
- msprobe/pytorch/free_benchmark/__init__.py +8 -0
- msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/common/constant.py +67 -0
- msprobe/pytorch/free_benchmark/common/counter.py +72 -0
- msprobe/pytorch/free_benchmark/common/enums.py +37 -0
- msprobe/pytorch/free_benchmark/common/params.py +129 -0
- msprobe/pytorch/free_benchmark/common/utils.py +98 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
- msprobe/pytorch/free_benchmark/main.py +102 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
- msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
- msprobe/pytorch/functional/__init__.py +0 -0
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +39 -0
- msprobe/pytorch/hook_module/__init__.py +1 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -0
- msprobe/pytorch/hook_module/hook_module.py +109 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
- msprobe/pytorch/hook_module/utils.py +29 -0
- msprobe/pytorch/hook_module/wrap_aten.py +100 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
- msprobe/pytorch/hook_module/wrap_functional.py +108 -0
- msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
- msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
- msprobe/pytorch/hook_module/wrap_torch.py +88 -0
- msprobe/pytorch/hook_module/wrap_vf.py +64 -0
- msprobe/pytorch/module_processer.py +98 -0
- msprobe/pytorch/online_dispatch/__init__.py +20 -0
- msprobe/pytorch/online_dispatch/compare.py +236 -0
- msprobe/pytorch/online_dispatch/dispatch.py +274 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
- msprobe/pytorch/online_dispatch/single_compare.py +391 -0
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
- msprobe/pytorch/online_dispatch/utils.py +187 -0
- msprobe/pytorch/parse.py +4 -0
- msprobe/pytorch/parse_tool/__init__.py +0 -0
- msprobe/pytorch/parse_tool/cli.py +32 -0
- msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
- msprobe/pytorch/parse_tool/lib/compare.py +259 -0
- msprobe/pytorch/parse_tool/lib/config.py +51 -0
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
- msprobe/pytorch/parse_tool/lib/utils.py +367 -0
- msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
- msprobe/pytorch/pt_config.py +93 -0
- msprobe/pytorch/service.py +167 -0
- msprobe/test/core_ut/common/test_utils.py +345 -0
- msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
- msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
- msprobe/test/core_ut/data_dump/test_scope.py +151 -0
- msprobe/test/core_ut/test_common_config.py +152 -0
- msprobe/test/core_ut/test_file_check.py +218 -0
- msprobe/test/core_ut/test_log.py +109 -0
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
- msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
- msprobe/test/mindspore_ut/test_ms_config.py +69 -0
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
- msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
- msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
- msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
- msprobe/test/pytorch_ut/test_pt_config.py +69 -0
- msprobe/test/pytorch_ut/test_service.py +59 -0
- msprobe/test/resources/advisor.txt +3 -0
- msprobe/test/resources/compare_result_20230703104808.csv +9 -0
- msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
- msprobe/test/resources/config.yaml +3 -0
- msprobe/test/resources/npu_test.pkl +8 -0
- msprobe/test/run_test.sh +30 -0
- msprobe/test/run_ut.py +58 -0
- msprobe/test/test_module_processer.py +64 -0
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# coding=utf-8
|
|
2
|
+
import os
|
|
3
|
+
import copy
|
|
4
|
+
import unittest
|
|
5
|
+
import torch
|
|
6
|
+
from unittest.mock import patch, DEFAULT
|
|
7
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import *
|
|
8
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents
|
|
9
|
+
|
|
10
|
+
base_dir = os.path.dirname(os.path.realpath(__file__))
forward_file = os.path.join(base_dir, "forward.json")
forward_content = get_json_contents(forward_file)
if not forward_content:
    raise ValueError(f"No API records found in {forward_file}")
# The tests below exercise a single API record: bind the last record of
# forward.json explicitly rather than relying on ``for``-loop variables
# leaking into module scope.
# NOTE(review): assumes forward.json holds exactly one record — confirm.
api_full_name, api_info_dict = list(forward_content.items())[-1]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TestRunUtMethods(unittest.TestCase):
    """Unit tests for the parameter-generation and execution helpers of ``run_ut``."""

    @staticmethod
    def _build_cpu_params():
        """Turn the module-level forward.json record into CPU-side call arguments.

        Returns:
            tuple: ``(api_type, api_name, cpu_args, cpu_kwargs)``.
        """
        # Deep-copy so a test that mutates its arguments cannot affect the shared fixture.
        api_info = copy.deepcopy(api_info_dict)
        # The full name has exactly four dot-separated parts; only type and name are used.
        api_type, api_name, _, _ = api_full_name.split(".")
        # get_api_info also returns a third value (need_grad) that these tests do not use.
        args, kwargs, _ = get_api_info(api_info, api_name, None)
        cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, True, '')
        return api_type, api_name, cpu_args, cpu_kwargs

    def test_exec_api(self):
        api_type, api_name, cpu_args, cpu_kwargs = self._build_cpu_params()
        out = exec_api(api_type, api_name, cpu_args, cpu_kwargs)
        self.assertEqual(out[0].dtype, torch.float32)
        self.assertTrue(out[0].requires_grad)
        self.assertEqual(out[0].shape, torch.Size([2048, 2, 1, 128]))

    def test_generate_device_params(self):
        mock_tensor = torch.rand([2, 2560, 24, 24], dtype=torch.float32, requires_grad=True)

        # Patch the tensor methods used while moving params to the device so the
        # test runs without a real device; each patched method hands back mock_tensor.
        with patch.multiple('torch.Tensor',
                            to=DEFAULT,
                            clone=DEFAULT,
                            detach=DEFAULT,
                            requires_grad_=DEFAULT,
                            type_as=DEFAULT,
                            retain_grad=DEFAULT) as mocks:
            mocks['clone'].return_value = mock_tensor
            mocks['detach'].return_value = mock_tensor
            mocks['requires_grad_'].return_value = mock_tensor
            mocks['type_as'].return_value = mock_tensor
            mocks['retain_grad'].return_value = None
            mocks['to'].return_value = mock_tensor

            device_args, device_kwargs = generate_device_params([mock_tensor], {'inplace': False}, True, '')
            self.assertEqual(len(device_args), 1)
            self.assertEqual(device_args[0].dtype, torch.float32)
            self.assertTrue(device_args[0].requires_grad)
            self.assertEqual(device_args[0].shape, torch.Size([2, 2560, 24, 24]))
            self.assertEqual(device_kwargs, {'inplace': False})

    def test_generate_cpu_params(self):
        _, _, cpu_args, cpu_kwargs = self._build_cpu_params()
        self.assertEqual(len(cpu_args), 2)
        self.assertEqual(cpu_args[0].dtype, torch.float32)
        self.assertTrue(cpu_args[0].requires_grad)
        self.assertEqual(cpu_args[0].shape, torch.Size([2048, 2, 1, 256]))
        self.assertEqual(cpu_kwargs, {'dim': -1})

    def test_UtDataInfo(self):
        data_info = UtDataInfo(None, None, None, None, None, None, None)
        self.assertIsNone(data_info.bench_grad)
        self.assertIsNone(data_info.device_grad)
        self.assertIsNone(data_info.device_output)
        self.assertIsNone(data_info.bench_output)
        self.assertIsNone(data_info.grad_in)
        self.assertIsNone(data_info.in_fwd_data_list)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# coding=utf-8
|
|
2
|
+
import unittest
|
|
3
|
+
from msprobe.pytorch.compare.acc_compare import rename_api
|
|
4
|
+
|
|
5
|
+
class TestUtilsMethods(unittest.TestCase):
    """Tests for ``rename_api`` from the accuracy-compare module."""

    def test_rename_api(self):
        # (dump name, stage, expected name with the call index and stage removed)
        cases = (
            ("Distributed.broadcast.0.forward.input.0", "forward", "Distributed.broadcast.input.0"),
            ("Torch.sum.0.backward.output.0", "backward", "Torch.sum.output.0"),
        )
        for raw_name, stage, expected in cases:
            renamed = rename_api(raw_name, stage)
            self.assertEqual(renamed, expected)
|
|
17
|
+
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
from unittest import TestCase
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from msprobe.core.common.const import Const
|
|
5
|
+
from msprobe.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode
|
|
6
|
+
from msprobe.pytorch.free_benchmark.common.params import data_pre_deal
|
|
7
|
+
from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TestPerturbedLayer(TestCase):
    """Tests for the NPU perturbation layers produced by ``LayerFactory``."""

    def test_improve_precision_layer_handle_with_out_dtype_changing(self):
        # For an operator whose output precision follows its input precision,
        # the improve-precision perturbation also raises the output precision.
        api_name = "Torch.mul.0.forward"
        lhs = torch.randn(2, 3, dtype=torch.float16)
        rhs = torch.randn(2, 3, dtype=torch.float16)
        baseline = torch.mul(lhs, rhs)

        params = data_pre_deal(api_name, torch.mul, (lhs, rhs), {})
        params.fuzz_stage = Const.FORWARD
        params.original_result = baseline

        layer = LayerFactory.create(api_name, DeviceType.NPU, PerturbationMode.IMPROVE_PRECISION)
        layer.handle(params)

        self.assertEqual(params.original_result.dtype, torch.float16)
        self.assertEqual(layer.perturbed_value, torch.float32)
        self.assertEqual(params.perturbed_result.dtype, torch.float32)

    def test_improve_precision_layer_with_iterable_inputs(self):
        # For iterable inputs the layer walks the (nested) elements and upcasts only
        # supported dtypes: bf16/fp16 become fp32 while fp32, fp64 and int32 are kept.
        bf16_t = torch.randn(2, 3, dtype=torch.bfloat16)
        fp16_t = torch.randn(2, 3, dtype=torch.float16)
        fp32_t = torch.randn(2, 3, dtype=torch.float32)
        fp64_t = torch.randn(2, 3, dtype=torch.float64)
        int32_t = torch.randn(2, 3, dtype=torch.float64).to(torch.int32)
        nested_inputs = [bf16_t, fp16_t, {"c": fp32_t, "d": fp64_t}, int32_t]

        layer = LayerFactory.create(
            "iterable.0.forward", DeviceType.NPU, PerturbationMode.IMPROVE_PRECISION
        )
        upcast = layer.improve_tensor_precision(nested_inputs)

        self.assertEqual(upcast[0].dtype, torch.float32)
        self.assertEqual(upcast[1].dtype, torch.float32)
        self.assertEqual(upcast[2]["c"].dtype, torch.float32)
        self.assertEqual(upcast[2]["d"].dtype, torch.float64)
        self.assertEqual(upcast[3].dtype, torch.int32)

    def test_no_change_layer(self):
        # The no_change perturbation leaves its input untouched.
        values = torch.as_tensor([1e-9, 1e-2], dtype=torch.float32)
        layer = LayerFactory.create("nochange.0.forward", DeviceType.NPU, PerturbationMode.NO_CHANGE)
        unchanged = layer.no_change(values)
        self.assertEqual(unchanged[0], 1e-9)
        self.assertEqual(unchanged[1], 1e-2)

    def test_change_value_layer(self):
        # For 1-D and 2-D tensors, change_value swaps the first and last values.
        vec = torch.as_tensor([1e-9, 1e-7, 1e-2], dtype=torch.float32)
        mat = torch.as_tensor([[1e-9, 1e-7, 1e-2], [1e-9, 1e-2, 1e-7]], dtype=torch.float32)
        layer = LayerFactory.create("change.0.forward", DeviceType.NPU, PerturbationMode.CHANGE_VALUE)

        swapped_vec = layer.change_value(vec)
        # Reset is_added before reusing the layer; presumably it guards against
        # perturbing more than once — confirm in change_value.py.
        layer.is_added = False
        swapped_mat = layer.change_value(mat)

        self.assertEqual(swapped_vec[0], 1e-2)
        self.assertEqual(swapped_vec[2], 1e-9)
        self.assertEqual(swapped_mat[0][0], 1e-7)
        self.assertEqual(swapped_mat[-1][-1], 1e-9)

    def test_bit_noise_layer(self):
        # bit_noise flips the trailing bit of input values that are above a tiny threshold.
        values = torch.as_tensor([4096.00048828125, 16777216, 1e-38], dtype=torch.float32)
        layer = LayerFactory.create("bitnoise.0.forward", DeviceType.NPU, PerturbationMode.BIT_NOISE)
        flipped = layer.add_bit_noise(values)
        self.assertEqual(flipped[0], 4096.0000000000)
        self.assertEqual(flipped[1], 16777218)
        self.assertEqual(flipped[2], 1e-38)

    def test_add_noise_layer(self):
        # add_noise adds a small value to input values that are above a tiny threshold.
        values = torch.as_tensor([1e-1, 1e-2], dtype=torch.bfloat16)
        layer = LayerFactory.create("addnoise.0.forward", DeviceType.NPU, PerturbationMode.ADD_NOISE)
        noisy = layer.add_noise(values)
        self.assertEqual(noisy[0], 1e-1 + 1e-4)
        self.assertEqual(noisy[1], 1e-2)
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from unittest import TestCase
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from msprobe.core.common.const import Const
|
|
6
|
+
from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig, ThresholdConfig
|
|
7
|
+
from msprobe.pytorch.free_benchmark.common.counter import preheat_counter
|
|
8
|
+
from msprobe.pytorch.free_benchmark.common.enums import (
|
|
9
|
+
DeviceType,
|
|
10
|
+
FuzzLevel,
|
|
11
|
+
HandlerType,
|
|
12
|
+
PerturbationMode,
|
|
13
|
+
)
|
|
14
|
+
from msprobe.pytorch.free_benchmark.common.params import DataParams, make_handler_params
|
|
15
|
+
from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import (
|
|
16
|
+
FuzzHandlerFactory,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Config(ABC):
    """Minimal stand-in for the free-benchmark configuration used by these handler tests."""

    def __init__(self, handler_type, preheat_config):
        # Settings chosen by each test.
        self.handler_type = handler_type
        self.preheat_config = preheat_config
        # Settings fixed for every test in this module.
        self.fuzz_stage = Const.FORWARD
        self.fuzz_device = DeviceType.NPU
        self.fuzz_level = FuzzLevel.BASE_LEVEL
        self.pert_mode = PerturbationMode.IMPROVE_PRECISION
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TestFuzzHandler(TestCase):
    """Tests for the CHECK, FIX and preheat result handlers built by FuzzHandlerFactory."""

    def setUp(self) -> None:
        base_inputs = [
            torch.as_tensor([3.01, 3.02], dtype=torch.float16),
            torch.as_tensor([0.02, 0.02], dtype=torch.float16),
        ]
        # Scale the inputs by a factor above the 1.002 error threshold to simulate
        # a discrepancy on the second (perturbed) execution.
        shifted_inputs = [(t * 1.0021).to(torch.float32).to("cpu") for t in base_inputs]
        # A DataParams instance whose perturbed result deviates from the original.
        self.data_params = DataParams(
            args=base_inputs,
            kwargs={},
            original_result=torch.add(*base_inputs),
            perturbed_result=torch.add(*shifted_inputs),
            origin_func=torch.add,
        )
        self.api_name = "add.0.forward"
        self.step = 0

    def _handle_repeatedly(self, config, step, times):
        """Create a fresh handler for ``self.api_name`` at ``step`` and run it, ``times`` times."""
        for _ in range(times):
            handler_params = make_handler_params(self.api_name, config, step)
            FuzzHandlerFactory.create(handler_params).handle(self.data_params)

    def test_result_handler_check(self):
        # The CHECK handler produces an UnequalRow object whenever the outputs
        # before and after perturbation disagree.
        for _ in range(2):
            config = Config(HandlerType.CHECK, {PreheatConfig.IF_PREHEAT: False})
            handler = FuzzHandlerFactory.create(
                make_handler_params(self.api_name, config, self.step)
            )
            handler.handle(self.data_params)
            self.assertEqual(len(handler.get_unequal_rows()), 1)

    def test_result_handler_fix(self):
        # The FIX handler substitutes the perturbed output for the original one:
        # dtype matches the original output, values come from the perturbed output.
        config = Config(HandlerType.FIX, {PreheatConfig.IF_PREHEAT: False})
        handler = FuzzHandlerFactory.create(
            make_handler_params(self.api_name, config, self.step)
        )
        result = handler.handle(self.data_params)
        self.assertEqual(result.dtype, torch.float16)
        self.assertEqual(result.device, self.data_params.original_result.device)
        expected = self.data_params.perturbed_result.to(torch.float16)
        for idx in (0, 1):
            self.assertAlmostEqual(result[idx], expected[idx])

    def test_result_handler_preheat(self):
        # For the preheat handler, thresholds are adjusted against the CPU after the preheat phase.
        config = Config(
            HandlerType.CHECK,
            {
                PreheatConfig.IF_PREHEAT: True,
                PreheatConfig.PREHEAT_STEP: 4,
                PreheatConfig.MAX_SAMPLE: 3,
            },
        )
        self._handle_repeatedly(config, 0, 3)
        # preheat_counter shows whether preheating ran; the first step records how
        # many times the API was executed.
        self.assertEqual(preheat_counter.get_one_step_used_api("add"), 3)
        for step in range(1, 4):
            self._handle_repeatedly(config, step, 3)
            # The call time is the number of calls of this API in the current step.
            self.assertEqual(preheat_counter.get_api_called_time("add"), 3)
            # With at most three samples over three steps, each step samples exactly one call.
            self.assertEqual(preheat_counter.get_api_sample_time("add"), 1)
            # While preheating, the API threshold must lie between the two threshold hyper-parameters.
            api_threshold = preheat_counter.get_api_thd("add", "torch.float16")
            self.assertLessEqual(api_threshold, ThresholdConfig.PREHEAT_INITIAL_THD)
            self.assertGreaterEqual(api_threshold, ThresholdConfig.DTYPE_PER_THD[torch.float16])
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
from abc import ABC
|
|
3
|
+
from unittest import TestCase
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
from msprobe.core.common.const import Const
|
|
8
|
+
from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck
|
|
9
|
+
from msprobe.pytorch.free_benchmark.common.constant import CommonField, PreheatConfig
|
|
10
|
+
from msprobe.pytorch.free_benchmark.common.enums import (
|
|
11
|
+
DeviceType,
|
|
12
|
+
FuzzLevel,
|
|
13
|
+
HandlerType,
|
|
14
|
+
PerturbationMode,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Config(ABC):
    """Minimal configuration stub handed to FreeBenchmarkCheck in these tests."""

    def __init__(self, fuzz_stage, handler_type):
        self.fuzz_stage, self.handler_type = fuzz_stage, handler_type
        # Everything below is fixed for this module; preheating is always disabled.
        self.fuzz_device = DeviceType.NPU
        self.fuzz_level = FuzzLevel.BASE_LEVEL
        self.pert_mode = PerturbationMode.IMPROVE_PRECISION
        self.preheat_config = {PreheatConfig.IF_PREHEAT: False}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class WrapMul(nn.Module):
    """nn.Module wrapper around the mul operator; ``forward`` delegates to ``torch.mul``."""

    def __init__(self, op_name) -> None:
        super().__init__()
        # Stored for callers to inspect; forward() itself does not use it.
        self.op_name = op_name

    def forward(self, *args, **kwargs):
        """Return ``torch.mul`` applied to exactly the arguments received."""
        product = torch.mul(*args, **kwargs)
        return product
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class UnequalDataProcessor(ABC):
    """Test double for the consumer of unequal-result rows; it only records what it receives."""

    def __init__(self) -> None:
        super().__init__()
        # One entry per update_unequal_rows() call.
        self.unequal_rows: list = []

    def update_unequal_rows(self, unequal_rows):
        """Record ``unequal_rows`` as a single new entry (appended whole, not merged)."""
        self.unequal_rows += [unequal_rows]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class TestInterface(TestCase):
    """End-to-end checks of the FreeBenchmarkCheck forward and backward entry points."""

    def setUp(self):
        self.api_name = "Torch.mul.0"

    def testForwardFix(self):
        # Forward interface: with FIX enabled in the forward hook, the checker's
        # result is returned as the hook's output.
        checker = FreeBenchmarkCheck(Config(Const.FORWARD, HandlerType.FIX))
        # Run the operator's forward pass.
        lhs = torch.randn(2, 3).to(torch.float16)
        rhs = torch.randn(2, 3).to(torch.float16)
        module = WrapMul(self.api_name)
        output = module(lhs, rhs)
        # Simulate the forward hook calling the free-benchmark forward check.
        fixed, _ = checker.forward(
            self.api_name,
            module,
            args=(lhs, rhs),
            kwargs={},
            output=output,
        )
        self.assertEqual(fixed.dtype, torch.float32)

    def testBackwardCheck(self):
        # Backward interface: inputs are stashed in pre_forward and compared after backward.
        checker = FreeBenchmarkCheck(Config(Const.BACKWARD, HandlerType.CHECK))
        processor = UnequalDataProcessor()
        # Inputs and the upstream gradient.
        x = torch.tensor([2, 3], dtype=torch.float16, requires_grad=True)
        y = torch.tensor([2, 3], dtype=torch.float16, requires_grad=True)
        grad_output = torch.tensor([1, 1], dtype=torch.float16)
        backward_name = Const.SEP.join([self.api_name, Const.BACKWARD])
        # pre_forward is expected to attach a grad saver to the module.
        module = WrapMul(self.api_name)
        checker.pre_forward(backward_name, module, processor, (x, y), {})
        # Run the operator forward and backward; the backward hook yields the perturbed grad_input.
        out = module(x, y)
        checker.backward(backward_name, module, grad_output)
        out.backward(torch.ones_like(out))
        # The module must carry a grad saver, and since d(x*y)/dx == y the first
        # perturbed input gradient starts with y[0] == 2.
        self.assertTrue(hasattr(module, CommonField.GRADSAVER))
        grad_saver = getattr(module, CommonField.GRADSAVER)
        self.assertEqual(grad_saver.perturbed_grad_input[0][0], 2)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
from msprobe.pytorch import PrecisionDebugger
|
|
5
|
+
from msprobe.pytorch.functional.dump_module import module_dump, module_count
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TestDumpModule(unittest.TestCase):
    """Tests for ``module_dump`` from ``msprobe.pytorch.functional.dump_module``."""

    def setUp(self):
        self.module = nn.Linear(in_features=8, out_features=4)

    def test_module_dump(self):
        """After module_dump, the given module name is expected in ``module_count``."""
        PrecisionDebugger(dump_path="./dump")
        module_dump(self.module, "TestModule")
        # assertIn reports the container contents on failure, unlike assertTrue(x in y).
        self.assertIn("TestModule", module_count)
@@ -0,0 +1,130 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
from msprobe.pytorch.hook_module.api_registry import ApiRegistry, torch_version_above_2, is_gpu
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class TestApiRegistry(unittest.TestCase):
    """Tests for ``ApiRegistry``: storing, setting and swapping patched attributes."""

    def test_store_ori_attr(self):
        """store_ori_attr resolves plain and dotted names on the target object."""
        class A():
            a1 = 1

        class B():
            a = A()
            b1 = 1
            b2 = 2

        api_list = ["a.a1", "b1", "b2"]
        expect_output = {"a.a1": 1, "b1": 1, "b2": 2}
        actual_output = dict()
        ApiRegistry.store_ori_attr(B, api_list, actual_output)
        self.assertEqual(actual_output, expect_output)

    def test_set_api_attr(self):
        """set_api_attr writes plain names on the target and dotted names on sub-objects."""
        class A():
            a1 = 1

        class B():
            a = A().__class__
            b1 = 1

        attr_dict = {"a.a2": 2, "b2": 2, "b3": 3}
        ApiRegistry.set_api_attr(B, attr_dict)

        for k, v in attr_dict.items():
            if '.' in k:
                sub_module_name, sub_op = k.rsplit('.', 1)
                sub_module = getattr(B, sub_module_name, None)
                self.assertEqual(getattr(sub_module, sub_op), v)
            else:
                self.assertEqual(getattr(B, k), v)

    def test_api_modularity(self):
        """api_modularity installs the *_hook_attr entries onto the torch namespaces."""
        import torch
        import torch.distributed as dist

        reg = ApiRegistry()
        attr_dict = {"b2": 2, "b3": 3}
        reg.tensor_hook_attr = attr_dict
        reg.torch_hook_attr = attr_dict
        reg.functional_hook_attr = attr_dict
        reg.distributed_hook_attr = attr_dict
        reg.npu_distributed_hook_attr = attr_dict
        reg.aten_hook_attr = attr_dict
        reg.vf_hook_attr = attr_dict
        reg.torch_npu_hook_attr = attr_dict

        # NOTE(review): this patches the real torch modules and nothing undoes
        # it, so b2/b3 stay visible to tests that run afterwards.
        reg.api_modularity()
        self.assertEqual(torch.Tensor.b2, 2)
        self.assertEqual(torch.b2, 2)
        self.assertEqual(torch.nn.functional.b2, 2)
        self.assertEqual(dist.b2, 2)
        self.assertEqual(dist.distributed_c10d.b2, 2)
        # torch_npu is not installed in CI, so the torch_npu and
        # torch_npu.distributed targets are not checked here.
        if torch_version_above_2:
            self.assertEqual(torch.ops.aten.b2, 2)
        self.assertEqual(torch._VF.b2, 2)

    def test_api_originality(self):
        """api_originality runs after the hook tables are populated."""
        import torch
        import torch.distributed as dist

        reg = ApiRegistry()
        attr_dict = {"b2": 2, "b3": 3}
        reg.tensor_hook_attr = attr_dict
        reg.torch_hook_attr = attr_dict
        reg.functional_hook_attr = attr_dict
        reg.distributed_hook_attr = attr_dict
        reg.npu_distributed_hook_attr = attr_dict
        reg.aten_hook_attr = attr_dict
        reg.vf_hook_attr = attr_dict
        reg.torch_npu_hook_attr = attr_dict

        # NOTE(review): only the *_hook_attr tables are set before calling
        # api_originality(). If it restores from separate original-attribute
        # tables, the asserts below only hold because test_api_modularity
        # already patched torch earlier in the run — confirm and decouple.
        reg.api_originality()
        self.assertEqual(torch.Tensor.b2, 2)
        self.assertEqual(torch.b2, 2)
        self.assertEqual(torch.nn.functional.b2, 2)
        self.assertEqual(dist.b2, 2)
        self.assertEqual(dist.distributed_c10d.b2, 2)
        # torch_npu is not installed in CI, so the torch_npu and
        # torch_npu.distributed targets are not checked here.
        if torch_version_above_2:
            self.assertEqual(torch.ops.aten.b2, 2)
        self.assertEqual(torch._VF.b2, 2)

    def test_initialize_hook(self):
        """initialize_hook fills the hook tables for each supported API family."""
        from msprobe.pytorch.hook_module.api_registry import torch_without_guard_version

        def hook_test():
            pass

        reg = ApiRegistry()
        reg.initialize_hook(hook_test)
        # The hook tables are used as dicts throughout this class, so the old
        # ``[] == table`` comparison was always False and the assertFalse
        # checks could never fail. Truthiness checks that each table is non-empty.
        self.assertTrue(reg.tensor_hook_attr)
        self.assertTrue(reg.torch_hook_attr)
        self.assertTrue(reg.functional_hook_attr)
        self.assertTrue(reg.distributed_hook_attr)
        # Same guard the torch_npu.distributed checks use: only expected on NPU
        # builds where the distributed API is not version-guarded.
        if not is_gpu and not torch_without_guard_version:
            self.assertTrue(reg.npu_distributed_hook_attr)
        if torch_version_above_2:
            self.assertTrue(reg.aten_hook_attr)
        if not is_gpu:
            self.assertTrue(reg.torch_npu_hook_attr)
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
from unittest.mock import patch, Mock
|
|
3
|
+
|
|
4
|
+
from msprobe.pytorch.hook_module.hook_module import HOOKModule
|
|
5
|
+
|
|
6
|
+
class TestHookModule(unittest.TestCase):
    """Tests for ``HOOKModule.__call__`` with stub forward/backward hooks."""

    def test_call_1(self):
        """With ``_call_func`` replaced by a mock, calling the module returns its value."""
        def forward_pre_hook():
            return "result_input", "result_kwargs"

        def forward_hook():
            return 2

        def backward_hook():
            pass

        def hook(prefix):
            return forward_pre_hook, forward_hook, backward_hook

        HOOKModule.prefix_op_name_ = "123"
        test = HOOKModule(hook)
        test._call_func = Mock(return_value=1)
        result = test()
        self.assertEqual(result, 1)

    def test_call_2(self):
        """The forward hook's return value is what the call yields."""
        # Hooks are invoked positionally, so these parameter names are free
        # (renamed from ``input`` to avoid shadowing the builtin).
        def forward_pre_hook(module, args, kwargs):
            return args, kwargs

        def forward_hook(module, args, kwargs, output):
            return args

        def backward_hook():
            pass

        def hook(prefix):
            return forward_pre_hook, forward_hook, backward_hook

        HOOKModule.prefix_op_name_ = "123"
        value = 2
        test = HOOKModule(hook)
        # forward itself returns 1, but forward_hook returns the positional
        # args, so the call is expected to yield ``(value,)``.
        test.forward = Mock(return_value=1)
        result = test(value)
        self.assertEqual(result, (value,))
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
import torch
|
|
3
|
+
from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate, AtenOPPacketTemplate
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def hook(name):
    """Build the stub hook triple used by the aten wrapper tests.

    Returns ``(forward_pre_hook, forward_hook, backward_hook)``. The pre-hook
    passes its args/kwargs through unchanged, the forward hook always returns
    ``2``, and the backward hook does nothing. *name* is not used.
    """
    def passthrough_pre_hook(module, args, kwargs):
        return args, kwargs

    def constant_forward_hook(module, args, kwargs, output):
        return 2

    def noop_backward_hook():
        return None

    return passthrough_pre_hook, constant_forward_hook, noop_backward_hook
class TestWrapAten(unittest.TestCase):
    """Tests for ``AtenOPPacketTemplate`` wrapping ``torch.ops.aten.convolution``."""

    @staticmethod
    def _torch_major_minor(version):
        """Return ``(major, minor)`` parsed from *version*, or None if unparseable."""
        import re
        match = re.match(r"(\d+)\.(\d+)", str(version))
        if match is None:
            return None
        return int(match.group(1)), int(match.group(2))

    def setUp(self):
        # Compare release numbers numerically. The old string comparison
        # (``version <= '2.0'``) misorders versions such as "10.0".
        # A skip is also reported instead of silently passing.
        release = self._torch_major_minor(torch.__version__)
        if release is not None and release < (2, 0):
            self.skipTest("aten op wrapping tests require torch >= 2.0")
        self.aten_op = AtenOPPacketTemplate(torch.ops.aten.convolution, hook)

    def _run_convolution(self):
        """Call the wrapped packet with a small conv2d-shaped input set."""
        image = torch.randn(4, 3, 24, 24)
        kernel = torch.randn(10, 3, 3, 3)
        return self.aten_op(image, kernel, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)

    def test_atenop_attribute(self):
        self.assertEqual(self.aten_op.default.op, torch.ops.aten.convolution.default)
        self.assertEqual(self.aten_op.out.op, torch.ops.aten.convolution.out)

    def test_atenop_forward(self):
        # The stub forward hook replaces the op output with 2.
        self.assertEqual(self._run_convolution(), 2)

    def test_atenop_overload_forward(self):
        # NOTE(review): identical to test_atenop_forward and calls the packet,
        # not an overload; presumably it should call an overload such as
        # self.aten_op.default(...) — confirm against AtenOPTemplate.
        self.assertEqual(self._run_convolution(), 2)

    def test_atenop_nonattr(self):
        self.assertRaises(AttributeError, getattr, self.aten_op, "foo")

    def test_atenop_overloads(self):
        self.assertEqual(self.aten_op.overloads(), self.aten_op.opPacket.overloads())