mindstudio-probe 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
- mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
- mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
- mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
- mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
- mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
- msprobe/README.md +182 -0
- msprobe/__init__.py +0 -0
- msprobe/config/README.md +397 -0
- msprobe/config/config.json +28 -0
- msprobe/config/img/free_benchmark.png +0 -0
- msprobe/core/common/const.py +241 -0
- msprobe/core/common/exceptions.py +88 -0
- msprobe/core/common/file_check.py +265 -0
- msprobe/core/common/log.py +55 -0
- msprobe/core/common/utils.py +516 -0
- msprobe/core/common_config.py +58 -0
- msprobe/core/data_dump/data_collector.py +140 -0
- msprobe/core/data_dump/data_processor/base.py +245 -0
- msprobe/core/data_dump/data_processor/factory.py +61 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
- msprobe/core/data_dump/json_writer.py +116 -0
- msprobe/core/data_dump/scope.py +178 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/debugger/__init__.py +0 -0
- msprobe/mindspore/debugger/debugger_config.py +51 -0
- msprobe/mindspore/debugger/precision_debugger.py +32 -0
- msprobe/mindspore/doc/dump.md +65 -0
- msprobe/mindspore/dump/__init__.py +0 -0
- msprobe/mindspore/dump/api_kbk_dump.py +55 -0
- msprobe/mindspore/dump/dump_tool_factory.py +38 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
- msprobe/mindspore/ms_config.py +78 -0
- msprobe/mindspore/overflow_check/__init__.py +0 -0
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
- msprobe/mindspore/task_handler_factory.py +21 -0
- msprobe/msprobe.py +67 -0
- msprobe/pytorch/__init__.py +4 -0
- msprobe/pytorch/advisor/advisor.py +124 -0
- msprobe/pytorch/advisor/advisor_const.py +59 -0
- msprobe/pytorch/advisor/advisor_result.py +58 -0
- msprobe/pytorch/api_accuracy_checker/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
- msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
- msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
- msprobe/pytorch/common/__init__.py +2 -0
- msprobe/pytorch/common/compare_script.template +14 -0
- msprobe/pytorch/common/log.py +32 -0
- msprobe/pytorch/common/parse_json.py +37 -0
- msprobe/pytorch/common/utils.py +224 -0
- msprobe/pytorch/compare/acc_compare.py +1024 -0
- msprobe/pytorch/compare/distributed_compare.py +111 -0
- msprobe/pytorch/compare/highlight.py +100 -0
- msprobe/pytorch/compare/mapping.yaml +607 -0
- msprobe/pytorch/compare/match.py +36 -0
- msprobe/pytorch/compare/npy_compare.py +244 -0
- msprobe/pytorch/debugger/__init__.py +0 -0
- msprobe/pytorch/debugger/debugger_config.py +86 -0
- msprobe/pytorch/debugger/precision_debugger.py +95 -0
- msprobe/pytorch/doc/FAQ.md +193 -0
- msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
- msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
- msprobe/pytorch/doc/dump.md +207 -0
- msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
- msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
- msprobe/pytorch/doc/img/cpu_info.png +0 -0
- msprobe/pytorch/doc/img/module_compare.png +0 -0
- msprobe/pytorch/doc/parse_tool.md +286 -0
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
- msprobe/pytorch/doc/run_overflow_check.md +25 -0
- msprobe/pytorch/doc/在线精度比对.md +90 -0
- msprobe/pytorch/free_benchmark/__init__.py +8 -0
- msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/common/constant.py +67 -0
- msprobe/pytorch/free_benchmark/common/counter.py +72 -0
- msprobe/pytorch/free_benchmark/common/enums.py +37 -0
- msprobe/pytorch/free_benchmark/common/params.py +129 -0
- msprobe/pytorch/free_benchmark/common/utils.py +98 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
- msprobe/pytorch/free_benchmark/main.py +102 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
- msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
- msprobe/pytorch/functional/__init__.py +0 -0
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +39 -0
- msprobe/pytorch/hook_module/__init__.py +1 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -0
- msprobe/pytorch/hook_module/hook_module.py +109 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
- msprobe/pytorch/hook_module/utils.py +29 -0
- msprobe/pytorch/hook_module/wrap_aten.py +100 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
- msprobe/pytorch/hook_module/wrap_functional.py +108 -0
- msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
- msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
- msprobe/pytorch/hook_module/wrap_torch.py +88 -0
- msprobe/pytorch/hook_module/wrap_vf.py +64 -0
- msprobe/pytorch/module_processer.py +98 -0
- msprobe/pytorch/online_dispatch/__init__.py +20 -0
- msprobe/pytorch/online_dispatch/compare.py +236 -0
- msprobe/pytorch/online_dispatch/dispatch.py +274 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
- msprobe/pytorch/online_dispatch/single_compare.py +391 -0
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
- msprobe/pytorch/online_dispatch/utils.py +187 -0
- msprobe/pytorch/parse.py +4 -0
- msprobe/pytorch/parse_tool/__init__.py +0 -0
- msprobe/pytorch/parse_tool/cli.py +32 -0
- msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
- msprobe/pytorch/parse_tool/lib/compare.py +259 -0
- msprobe/pytorch/parse_tool/lib/config.py +51 -0
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
- msprobe/pytorch/parse_tool/lib/utils.py +367 -0
- msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
- msprobe/pytorch/pt_config.py +93 -0
- msprobe/pytorch/service.py +167 -0
- msprobe/test/core_ut/common/test_utils.py +345 -0
- msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
- msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
- msprobe/test/core_ut/data_dump/test_scope.py +151 -0
- msprobe/test/core_ut/test_common_config.py +152 -0
- msprobe/test/core_ut/test_file_check.py +218 -0
- msprobe/test/core_ut/test_log.py +109 -0
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
- msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
- msprobe/test/mindspore_ut/test_ms_config.py +69 -0
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
- msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
- msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
- msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
- msprobe/test/pytorch_ut/test_pt_config.py +69 -0
- msprobe/test/pytorch_ut/test_service.py +59 -0
- msprobe/test/resources/advisor.txt +3 -0
- msprobe/test/resources/compare_result_20230703104808.csv +9 -0
- msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
- msprobe/test/resources/config.yaml +3 -0
- msprobe/test/resources/npu_test.pkl +8 -0
- msprobe/test/run_test.sh +30 -0
- msprobe/test/run_ut.py +58 -0
- msprobe/test/test_module_processer.py +64 -0
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from msprobe.core.common_config import CommonConfig, BaseConfig
|
|
5
|
+
from msprobe.core.common.file_check import FileOpen
|
|
6
|
+
from msprobe.core.common.const import Const
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TensorConfig(BaseConfig):
    """Configuration for the "tensor" dump task.

    After the generic BaseConfig initialisation and ``check_config()``, the
    optional ``file_format`` attribute (presumably populated by BaseConfig from
    the JSON section — confirm) is validated.
    """

    def __init__(self, json_config):
        super().__init__(json_config)
        self.check_config()
        self._check_file_format()

    def _check_file_format(self):
        """Raise ``Exception`` when ``file_format`` is set to anything but "npy" or "bin"."""
        requested_format = self.file_format
        if requested_format is None:
            # Not configured: nothing to validate.
            return
        if requested_format not in ("npy", "bin"):
            raise Exception("file_format is invalid")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class StatisticsConfig(BaseConfig):
    """Configuration for the "statistics" dump task.

    Runs the generic BaseConfig checks, then validates the optional
    ``summary_mode`` attribute (presumably populated by BaseConfig — confirm).
    """

    def __init__(self, json_config):
        super().__init__(json_config)
        self.check_config()
        self._check_summary_mode()

    def _check_summary_mode(self):
        """Raise ``Exception`` when a non-empty ``summary_mode`` is not "statistics" or "md5"."""
        mode = self.summary_mode
        if not mode:
            # Unset or empty: treated as "use the default", nothing to check.
            return
        if mode not in ("statistics", "md5"):
            raise Exception("summary_mode is invalid")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class OverflowCheckConfig(BaseConfig):
    """Configuration for the "overflow_check" task.

    Reads ``overflow_nums`` (stored as ``self.overflow_num``) and ``check_mode``
    from the task's JSON section and validates them on construction.
    """

    def __init__(self, json_config):
        super().__init__(json_config)
        self.overflow_num = json_config.get("overflow_nums")
        self.check_mode = json_config.get("check_mode")
        self.check_overflow_config()

    def check_overflow_config(self):
        """Validate the overflow-check specific options.

        Raises:
            Exception: if ``overflow_num`` is set but is not an integer, or if
                ``check_mode`` is set but is not one of "all", "aicore", "atomic".
        """
        overflow_num = self.overflow_num
        # bool is a subclass of int, so isinstance(True, int) is True; reject
        # JSON true/false explicitly instead of silently treating them as 1/0.
        if overflow_num is not None and (isinstance(overflow_num, bool) or not isinstance(overflow_num, int)):
            raise Exception("overflow_num is invalid")
        if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]:
            raise Exception("check_mode is invalid")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class FreeBenchmarkCheckConfig(BaseConfig):
    """Configuration for the "free_benchmark" task.

    Copies the free-benchmark options from the task's JSON section onto the
    instance (missing keys become ``None``) and validates their combination.
    """

    def __init__(self, json_config):
        super().__init__(json_config)
        self.fuzz_device = json_config.get("fuzz_device")
        self.pert_mode = json_config.get("pert_mode")
        self.handler_type = json_config.get("handler_type")
        self.fuzz_level = json_config.get("fuzz_level")
        self.fuzz_stage = json_config.get("fuzz_stage")
        self.if_preheat = json_config.get("if_preheat")
        self.preheat_step = json_config.get("preheat_step")
        self.max_sample = json_config.get("max_sample")
        self.check_freebenchmark_config()

    def check_freebenchmark_config(self):
        """Validate option combinations.

        Raises:
            Exception: if preheating is requested together with the "fix"
                handler type, or if ``preheat_step`` is 0.
        """
        if self.if_preheat and self.handler_type == "fix":
            raise Exception("Preheating is not supported in fix handler type")
        # Compare directly: 0 is falsy, so a truthiness guard in front of this
        # test would skip exactly the value that must be rejected. None != 0,
        # so an unset preheat_step still passes.
        if self.preheat_step == 0:
            raise Exception("preheat_step cannot be 0")
|
|
63
|
+
|
|
64
|
+
def parse_task_config(task, json_config):
    """Build the task-specific configuration object for ``task``.

    The JSON section keyed by the task name is handed to the matching config
    class; a missing or falsy section is replaced by an empty dict. Any task
    that is not recognised falls back to an empty StatisticsConfig.
    """
    # Ordered (task name, config class) pairs; compared with == exactly like
    # an if/elif chain would.
    task_classes = (
        (Const.TENSOR, TensorConfig),
        (Const.STATISTICS, StatisticsConfig),
        (Const.OVERFLOW_CHECK, OverflowCheckConfig),
        (Const.FREE_BENCHMARK, FreeBenchmarkCheckConfig),
    )
    for task_name, config_cls in task_classes:
        if task == task_name:
            section = json_config.get(task_name) or {}
            return config_cls(section)
    return StatisticsConfig({})
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def parse_json_config(json_file_path, task):
    """Load a msprobe JSON configuration file.

    When ``json_file_path`` is empty, ``config/config.json`` located two
    directory levels above this module is used instead. If ``task`` is a
    non-empty entry of ``Const.TASK_LIST`` it selects the task section;
    otherwise the task named in the common configuration is used.

    Returns:
        tuple: ``(common_config, task_config)``.
    """
    if not json_file_path:
        package_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        json_file_path = os.path.join(package_root, "config", "config.json")
    with FileOpen(json_file_path, 'r') as config_file:
        json_config = json.load(config_file)
    common_config = CommonConfig(json_config)
    selected_task = task if (task and task in Const.TASK_LIST) else common_config.task
    task_config = parse_task_config(selected_task, json_config)
    return common_config, task_config
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import os
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from msprobe.pytorch.common.log import logger
|
|
6
|
+
from msprobe.core.common.file_check import FileChecker, check_path_before_create
|
|
7
|
+
from msprobe.core.common.const import Const, FileCheckConst
|
|
8
|
+
from msprobe.core.common.exceptions import DistributedNotInitializedError, MsaccException
|
|
9
|
+
from msprobe.core.data_dump.data_collector import build_data_collector
|
|
10
|
+
from msprobe.core.data_dump.scope import BaseScope
|
|
11
|
+
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
|
|
12
|
+
from msprobe.pytorch.common.utils import get_rank_if_initialized
|
|
13
|
+
from msprobe.pytorch.module_processer import ModuleProcesser
|
|
14
|
+
from msprobe.pytorch.hook_module import remove_dropout
|
|
15
|
+
from msprobe.pytorch.hook_module.api_registry import api_register
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Service:
    """Drive data dumping for a PyTorch model.

    Builds a data collector from ``config``, registers hooks on the model's
    submodules and/or on wrapped APIs (selected by ``config.level``), and turns
    collection on/off per iteration via ``start`` / ``stop`` / ``step``.
    """

    def __init__(self, config):
        self.model = None
        self.config = config
        self.data_collector = build_data_collector(config)
        self.module_processor = ModuleProcesser(self.data_collector.scope)
        # Hooks only collect data while this is True (set in start(), cleared in stop()).
        self.switch = False
        self.current_iter = 0
        # Hooks are registered once, on the first start() call that passes the step/rank filters.
        self.first_start = True
        self.current_rank = None
        self.dump_iter_dir = None

    def build_hook(self, module_type, name):
        """Create the pre-forward, forward and backward hook callables for one module or API.

        Args:
            module_type: ``BaseScope.Module_Type_Module`` for nn.Module hooks, in which
                case the name stored on the module (``mindstudio_reserved_name``) is used
                instead of the bound name; any other value keeps the bound name.
            name: name prefix; ``Const.FORWARD`` / ``Const.BACKWARD`` is appended to it.

        Returns:
            tuple: ``(pre_forward_hook, forward_hook, backward_hook)``, each a
            ``functools.partial`` with its forward/backward name bound as the first argument.
        """
        def pre_hook(api_or_module_name, module, args, kwargs):
            if module_type == BaseScope.Module_Type_Module:
                api_or_module_name = module.mindstudio_reserved_name
            # Called even when the switch is off, before the early return below.
            self.data_collector.visit_and_clear_overflow_status(api_or_module_name)

            if not self.switch:
                return args, kwargs
            if self.data_collector:
                module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
                # pid is read from the enclosing scope; it is assigned below, before any hook can run.
                self.data_collector.pre_forward_data_collect(api_or_module_name, module, pid, module_input_output)
            return args, kwargs

        def forward_hook(api_or_module_name, module, args, kwargs, output):
            if module_type == BaseScope.Module_Type_Module:
                api_or_module_name = module.mindstudio_reserved_name
            self.data_collector.visit_and_clear_overflow_status(api_or_module_name)

            if not self.switch:
                # A forward hook returning None leaves the module output unchanged.
                return None
            if self.data_collector:
                module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
                self.data_collector.forward_data_collect(api_or_module_name, module, pid, module_input_output)
                # The collector may ask for a replacement output to be returned instead of the original.
                if self.data_collector.if_return_forward_new_output():
                    return self.data_collector.get_forward_new_output()
            return output

        def backward_hook(api_or_module_name, module, grad_input, grad_output):
            if module_type == BaseScope.Module_Type_Module:
                api_or_module_name = module.mindstudio_reserved_name
            self.data_collector.visit_and_clear_overflow_status(api_or_module_name)

            if not self.switch:
                return
            if self.data_collector:
                module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
                self.data_collector.backward_data_collect(api_or_module_name, module, pid, module_input_output)

        # Captured by the closures above; must stay assigned before the hooks are returned.
        pid = os.getpid()
        forward_name_template = name + Const.FORWARD
        backward_name_template = name + Const.BACKWARD
        pre_forward_hook = functools.partial(pre_hook, forward_name_template)
        forward_hook = functools.partial(forward_hook, forward_name_template)
        backward_hook = functools.partial(backward_hook, backward_name_template)
        return pre_forward_hook, forward_hook, backward_hook

    def step(self):
        """Advance the iteration counter and forward it to the data collector."""
        self.current_iter += 1
        self.data_collector.update_iter(self.current_iter)

    def start(self, model):
        """Turn dumping on for the current iteration, registering hooks on first use.

        Raises an ``Exception`` (after calling ``stop()``) once ``current_iter`` exceeds
        the largest configured step. Returns early on iterations not listed in
        ``config.step``, and on the first call when the current rank is not in
        ``config.rank`` (``first_start`` then stays True, so later calls re-check).
        """
        self.model = model
        if self.config.step and self.current_iter > max(self.config.step):
            self.stop()
            raise Exception("msprobe: exit after iteration {}".format(max(self.config.step)))
        if self.config.step and self.current_iter not in self.config.step:
            return
        if self.first_start:
            try:
                self.current_rank = get_rank_if_initialized()
            except DistributedNotInitializedError:
                # Non-distributed run: no rank is recorded.
                self.current_rank = None

            if self.config.rank and self.current_rank not in self.config.rank:
                return
            self.register_hook_new()
            self.first_start = False
        self.switch = True
        logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ")
        # Output directories are not created for level "L2".
        if self.config.level != "L2":
            self.create_dirs()
            logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")

    def stop(self):
        """Turn dumping off and write collected data via ``data_collector.write_json()``.

        No-op for level "L2", for iterations not in ``config.step`` and for ranks
        not in ``config.rank``.
        """
        if self.config.level == "L2":
            return
        if self.config.step and self.current_iter not in self.config.step:
            return
        if self.config.rank and self.current_rank not in self.config.rank:
            return
        self.switch = False
        self.data_collector.write_json()

    def create_dirs(self):
        """Create the per-step/per-rank output directories and pass file paths to the collector.

        Layout: ``<dump_path>/step<iter>/rank<rank>/`` containing dump.json, stack.json,
        construct.json and, for tasks in ``data_collector.tasks_need_tensor_data``, a
        ``dump_tensor_data`` directory. ``free_benchmark.csv`` lives directly in
        ``dump_path``. The rank suffix is empty when no rank is known.
        """
        check_path_before_create(self.config.dump_path)
        if not os.path.exists(self.config.dump_path):
            # No parents=True here: only the last path component is created.
            Path(self.config.dump_path).mkdir(mode=0o750, exist_ok=True)
        file_check = FileChecker(self.config.dump_path, FileCheckConst.DIR)
        file_check.common_check()
        self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
        cur_rank = self.current_rank if self.current_rank is not None else ''
        dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
        if not os.path.exists(dump_dir):
            Path(dump_dir).mkdir(mode=0o750, parents=True, exist_ok=True)
        if self.config.task in self.data_collector.tasks_need_tensor_data:
            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
            Path(dump_data_dir).mkdir(mode=0o750, exist_ok=True)
        else:
            dump_data_dir = None

        dump_file_path = os.path.join(dump_dir, "dump.json")
        stack_file_path = os.path.join(dump_dir, "stack.json")
        construct_file_path = os.path.join(dump_dir, "construct.json")
        free_benchmark_file_path = os.path.join(self.config.dump_path, "free_benchmark.csv")
        self.data_collector.update_dump_paths(
            dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path)

    def register_hook_new(self):
        """Register dump hooks according to ``config.level``.

        - "L0" / "mix": data-collection hooks plus ModuleProcesser node hooks on every
          submodule of ``self.model`` (the root module itself is skipped).
        - "mix" / "L1" / "L2": API-level hooks installed through ``api_register``.
        - "statistics" / "tensor" tasks: additionally calls ``remove_dropout()``.
        """
        logger.info_on_rank_0("The {} hook function is successfully mounted to the model.".format(self.config.task))
        if self.config.level in ["L0", "mix"]:
            if self.model is None:
                # presumably logs and raises the given exception — confirm in logger implementation
                logger.error_log_with_exp("The model is None.", MsaccException.INVALID_PARAM_ERROR)
            # NOTE(review): message says module dump is unavailable, yet module hooks are
            # registered just below — confirm the intended wording.
            logger.info_on_rank_0("The init dump mode is enabled, and the module dump function will not be available")
            for name, module in self.model.named_modules():
                if module == self.model:
                    continue
                prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \
                         module.__class__.__name__ + Const.SEP

                # pre_forward_hook is built but not registered for modules.
                pre_forward_hook, forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix)
                # Registered before the node STOP hooks below, so data collection runs first
                # (PyTorch runs hooks of the same kind in registration order).
                module.register_forward_hook(forward_hook, with_kwargs=True)
                module.register_full_backward_hook(backward_hook)

                module.register_forward_pre_hook(
                    self.module_processor.node_hook(prefix + Const.FORWARD, Const.START))
                module.register_forward_hook(
                    self.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
                module.register_full_backward_pre_hook(
                    self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
                module.register_full_backward_hook(
                    self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))

        if self.config.level in ["mix", "L1", "L2"]:
            api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
            api_register.api_modularity()

        if Const.STATISTICS == self.config.task or Const.TENSOR == self.config.task:
            remove_dropout()
|
|
@@ -0,0 +1,345 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
import os
|
|
18
|
+
import uuid
|
|
19
|
+
|
|
20
|
+
from unittest import TestCase
|
|
21
|
+
from unittest.mock import patch, MagicMock, mock_open
|
|
22
|
+
|
|
23
|
+
from msprobe.core.common.log import logger
|
|
24
|
+
from msprobe.core.common.const import Const
|
|
25
|
+
from msprobe.core.common.utils import (CompareException,
|
|
26
|
+
check_seed_all,
|
|
27
|
+
check_inplace_op,
|
|
28
|
+
make_dump_path_if_not_exists,
|
|
29
|
+
check_mode_valid,
|
|
30
|
+
check_switch_valid,
|
|
31
|
+
check_dump_mode_valid,
|
|
32
|
+
check_summary_mode_valid,
|
|
33
|
+
check_summary_only_valid,
|
|
34
|
+
check_file_or_directory_path,
|
|
35
|
+
check_compare_param,
|
|
36
|
+
check_configuration_param,
|
|
37
|
+
is_starts_with,
|
|
38
|
+
_check_json,
|
|
39
|
+
check_json_file,
|
|
40
|
+
check_file_size,
|
|
41
|
+
check_regex_prefix_format_valid,
|
|
42
|
+
get_dump_data_path,
|
|
43
|
+
task_dumppath_get)
|
|
44
|
+
from msprobe.core.common.file_check import FileCheckConst
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class TestUtils(TestCase):
|
|
48
|
+
@patch.object(logger, "error")
|
|
49
|
+
def test_check_seed_all(self, mock_error):
|
|
50
|
+
self.assertIsNone(check_seed_all(1234, True))
|
|
51
|
+
self.assertIsNone(check_seed_all(0, True))
|
|
52
|
+
self.assertIsNone(check_seed_all(Const.MAX_SEED_VALUE, True))
|
|
53
|
+
|
|
54
|
+
with self.assertRaises(CompareException) as context:
|
|
55
|
+
check_seed_all(-1, True)
|
|
56
|
+
self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
|
|
57
|
+
mock_error.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
|
|
58
|
+
|
|
59
|
+
with self.assertRaises(CompareException) as context:
|
|
60
|
+
check_seed_all(Const.MAX_SEED_VALUE + 1, True)
|
|
61
|
+
self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
|
|
62
|
+
mock_error.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
|
|
63
|
+
|
|
64
|
+
with self.assertRaises(CompareException) as context:
|
|
65
|
+
check_seed_all("1234", True)
|
|
66
|
+
self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
|
|
67
|
+
mock_error.assert_called_with("Seed must be integer.")
|
|
68
|
+
|
|
69
|
+
with self.assertRaises(CompareException) as context:
|
|
70
|
+
check_seed_all(1234, 1)
|
|
71
|
+
self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
|
|
72
|
+
mock_error.assert_called_with("seed_all mode must be bool.")
|
|
73
|
+
|
|
74
|
+
def test_check_inplace_op(self):
|
|
75
|
+
test_prefix_1 = "Distributed.broadcast.0.forward.input.0"
|
|
76
|
+
self.assertTrue(check_inplace_op(test_prefix_1))
|
|
77
|
+
test_prefix_2 = "Distributed_broadcast_0_forward_input_0"
|
|
78
|
+
self.assertFalse(check_inplace_op(test_prefix_2))
|
|
79
|
+
test_prefix_3 = "Torch.sum.0.backward.output.0"
|
|
80
|
+
self.assertFalse(check_inplace_op(test_prefix_3))
|
|
81
|
+
|
|
82
|
+
@patch.object(logger, "error")
|
|
83
|
+
def test_make_dump_path_if_not_exists(self, mock_error):
|
|
84
|
+
file_path = os.path.realpath(__file__)
|
|
85
|
+
dirname = os.path.dirname(file_path) + str(uuid.uuid4())
|
|
86
|
+
|
|
87
|
+
def test_mkdir(self, **kwargs):
|
|
88
|
+
raise OSError
|
|
89
|
+
|
|
90
|
+
if not os.path.exists(dirname):
|
|
91
|
+
with patch("msprobe.core.common.utils.Path.mkdir", new=test_mkdir):
|
|
92
|
+
with self.assertRaises(CompareException) as context:
|
|
93
|
+
make_dump_path_if_not_exists(dirname)
|
|
94
|
+
self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR)
|
|
95
|
+
|
|
96
|
+
make_dump_path_if_not_exists(file_path)
|
|
97
|
+
mock_error.assert_called_with(f"{file_path} already exists and is not a directory.")
|
|
98
|
+
|
|
99
|
+
def test_check_mode_valid(self):
    """Exercise the rejection paths of check_mode_valid."""
    # Non-list scope / api_list arguments are rejected with ValueError.
    non_list_cases = (
        ({"scope": "scope"}, "scope param set invalid, it's must be a list."),
        ({"api_list": "api_list"}, "api_list param set invalid, it's must be a list."),
    )
    for kwargs, message in non_list_cases:
        with self.assertRaises(ValueError) as ctx:
            check_mode_valid("all", **kwargs)
        self.assertEqual(str(ctx.exception), message)

    # An unknown mode raises CompareException with INVALID_DUMP_MODE.
    unsupported_mode = "all_list"
    with self.assertRaises(CompareException) as ctx:
        check_mode_valid(unsupported_mode)
    self.assertEqual(ctx.exception.code, CompareException.INVALID_DUMP_MODE)
    self.assertEqual(str(ctx.exception),
                     f"Current mode '{unsupported_mode}' is not supported. Please use the field in {Const.DUMP_MODE}")

    # "list" mode without a scope is rejected with ValueError.
    with self.assertRaises(ValueError) as ctx:
        check_mode_valid("list")
    self.assertEqual(str(ctx.exception),
                     "set_dump_switch, scope param set invalid, it's should not be an empty list.")
|
|
120
|
+
|
|
121
|
+
@patch.object(logger, "error")
def test_check_switch_valid(self, mock_error):
    """A switch value other than 'ON'/'OFF' raises INVALID_PARAM_ERROR and logs an error."""
    with self.assertRaises(CompareException) as raised:
        check_switch_valid("Close")
    error_code = raised.exception.code
    self.assertEqual(error_code, CompareException.INVALID_PARAM_ERROR)
    mock_error.assert_called_with("Please set switch with 'ON' or 'OFF'.")
|
|
127
|
+
|
|
128
|
+
@patch.object(logger, "warning")
def test_check_dump_mode_valid(self, mock_warning):
    """A bare "all" string is expanded with a warning; "all_forward" is rejected."""
    expanded = check_dump_mode_valid("all")
    mock_warning.assert_called_with("Please set dump_mode as a list.")
    self.assertEqual(expanded, ["forward", "backward", "input", "output"])

    expected_message = ("Please set dump_mode as a list containing one or more of the following: "
                        "'all', 'forward', 'backward', 'input', 'output'.")
    with self.assertRaises(ValueError) as raised:
        check_dump_mode_valid("all_forward")
    self.assertEqual(str(raised.exception), expected_message)
|
|
139
|
+
|
|
140
|
+
def test_check_summary_mode_valid(self):
    """The summary mode "MD5" is rejected with INVALID_SUMMARY_MODE."""
    with self.assertRaises(CompareException) as raised:
        check_summary_mode_valid("MD5")
    exc = raised.exception
    self.assertEqual(exc.code, CompareException.INVALID_SUMMARY_MODE)
    self.assertEqual(str(exc), "The summary_mode is not valid")
|
|
145
|
+
|
|
146
|
+
@patch.object(logger, "error")
def test_check_summary_only_valid(self, mock_error):
    """A real bool is accepted; the string "True" raises INVALID_PARAM_ERROR."""
    self.assertTrue(check_summary_only_valid(True))

    with self.assertRaises(CompareException) as raised:
        check_summary_only_valid("True")
    self.assertEqual(raised.exception.code, CompareException.INVALID_PARAM_ERROR)
    mock_error.assert_called_with("Params summary_only only support True or False.")
|
|
155
|
+
|
|
156
|
+
def test_check_file_or_directory_path(self):
    """Files are checked as FILE/READ_ABLE, directories as DIR/WRITE_ABLE, and common_check runs."""

    class RecordingFileChecker:
        # Class-level fields capture what the most recent instance was built with.
        file_path = ""
        path_type = ""
        ability = ""
        checked = False

        def __init__(self, file_path, path_type, ability=None):
            RecordingFileChecker.file_path = file_path
            RecordingFileChecker.path_type = path_type
            RecordingFileChecker.ability = ability

        def common_check(self):
            RecordingFileChecker.checked = True

    this_file = os.path.realpath(__file__)
    this_dir = os.path.dirname(this_file)
    # (path, isdir flag, expected path_type, expected ability)
    cases = (
        (this_file, False, FileCheckConst.FILE, FileCheckConst.READ_ABLE),
        (this_dir, True, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE),
    )
    for path, isdir, expected_type, expected_ability in cases:
        RecordingFileChecker.checked = False
        with patch("msprobe.core.common.utils.FileChecker", new=RecordingFileChecker):
            check_file_or_directory_path(path, isdir=isdir)
            self.assertTrue(RecordingFileChecker.checked)
            self.assertEqual(RecordingFileChecker.file_path, path)
            self.assertEqual(RecordingFileChecker.path_type, expected_type)
            self.assertEqual(RecordingFileChecker.ability, expected_ability)
|
|
188
|
+
|
|
189
|
+
@patch.object(logger, "error")
def test_check_compare_param(self, mock_error):
    """Reject non-dict input, then verify every path argument is validated in order.

    A str in place of the params dict must raise INVALID_PARAM_ERROR. With the
    file helpers mocked, the positional args passed to
    check_file_or_directory_path are compared against ``expected_calls``. Then
    check_json_file is checked for 4 positional args, the first being ``params``.
    """
    params = {
        "npu_json_path": "npu_json_path",
        "bench_json_path": "bench_json_path",
        "stack_json_path": "stack_json_path",
        "npu_dump_data_dir": "npu_dump_data_dir",
        "bench_dump_data_dir": "bench_dump_data_dir"
    }

    # Positional args expected for check_file_or_directory_path, in call order,
    # across the two check_compare_param calls below (default, then md5_compare=True).
    expected_calls = [
        ("npu_json_path", False),
        ("bench_json_path", False),
        ("stack_json_path", False),
        ("npu_dump_data_dir", True),
        ("bench_dump_data_dir", True),
        ("output_path", True),
        ("npu_json_path", False),
        ("bench_json_path", False),
        ("stack_json_path", False),
        ("output_path", True)
    ]

    with self.assertRaises(CompareException) as context:
        check_compare_param("npu_json_path", "output_path")
    self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
    mock_error.assert_called_with("Invalid input parameters")

    mock_check_file_or_directory_path = MagicMock()
    mock_check_json_file = MagicMock()
    with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \
            patch("msprobe.core.common.utils.check_json_file", new=mock_check_json_file), \
            patch("msprobe.core.common.utils.check_file_or_directory_path", new=mock_check_file_or_directory_path):
        check_compare_param(params, "output_path")
        check_compare_param(params, "output_path", summary_compare=False, md5_compare=True)

    recorded_calls = [recorded[0] for recorded in mock_check_file_or_directory_path.call_args_list]
    # An explicit length check gives a readable failure instead of an IndexError.
    self.assertGreaterEqual(len(recorded_calls), len(expected_calls))
    self.assertEqual(recorded_calls[:len(expected_calls)], expected_calls)
    self.assertEqual(len(mock_check_json_file.call_args[0]), 4)
    self.assertEqual(mock_check_json_file.call_args[0][0], params)
|
|
228
|
+
|
|
229
|
+
@patch.object(logger, "error")
def test_check_configuration_param(self, mock_error):
    """A string stack_mode is rejected with INVALID_PARAM_ERROR and an error log."""
    flags = {"stack_mode": "False", "auto_analyze": True, "fuzzy_match": False}
    with self.assertRaises(CompareException) as raised:
        check_configuration_param(**flags)
    self.assertEqual(raised.exception.code, CompareException.INVALID_PARAM_ERROR)
    mock_error.assert_called_with("Invalid input parameters which should be only bool type.")
|
|
235
|
+
|
|
236
|
+
def test_is_starts_with(self):
    """Empty inputs or non-matching prefixes give False; a matching prefix gives True."""
    sample = "input_slot0"
    negative_cases = ((sample, []), ("", ["input"]), (sample, ["output"]))
    for text, prefixes in negative_cases:
        self.assertFalse(is_starts_with(text, prefixes))
    self.assertTrue(is_starts_with(sample, ["input", "output"]))
|
|
242
|
+
|
|
243
|
+
@patch.object(logger, "error")
def test__check_json(self, mock_error):
    """An empty first line raises INVALID_DUMP_FILE; otherwise the handle is seeked to (0, 0)."""

    class FakeHandle:
        # Minimal file stand-in: readline() returns the stored text and
        # seek() overwrites it with "<begin>_<end>" so the call can be inspected.
        def __init__(self, string):
            self.string = string

        def readline(self):
            return self.string

        def seek(self, begin, end):
            self.string = "_".join((str(begin), str(end)))

    with self.assertRaises(CompareException) as raised:
        _check_json(FakeHandle(""), "test.json")
    self.assertEqual(raised.exception.code, CompareException.INVALID_DUMP_FILE)
    mock_error.assert_called_with("dump file test.json have empty line!")

    handle = FakeHandle("jons file\n")
    _check_json(handle, "test.json")
    self.assertEqual(handle.string, "0_0")
|
|
263
|
+
|
|
264
|
+
@patch("msprobe.core.common.utils._check_json")
def test_check_json_file(self, _mock_check_json):
    """check_json_file passes each handle to _check_json with its path, npu/bench/stack in order."""
    input_param = {
        "npu_json_path": "npu_json_path",
        "bench_json_path": "bench_json_path",
        "stack_json_path": "stack_json_path"
    }
    check_json_file(input_param, "npu_json", "bench_json", "stack_json")

    expected = (
        ("npu_json", "npu_json_path"),
        ("bench_json", "bench_json_path"),
        ("stack_json", "stack_json_path"),
    )
    for index, args in enumerate(expected):
        self.assertEqual(_mock_check_json.call_args_list[index][0], args)
|
|
275
|
+
|
|
276
|
+
@patch.object(logger, "error")
def test_check_file_size(self, mock_error):
    """A file whose reported size exceeds the limit raises INVALID_FILE_ERROR."""
    reported_size, size_limit = 120, 100
    with patch("msprobe.core.common.utils.os.path.getsize", return_value=reported_size):
        with self.assertRaises(CompareException) as raised:
            check_file_size("input_file", size_limit)
    self.assertEqual(raised.exception.code, CompareException.INVALID_FILE_ERROR)
    mock_error.assert_called_with("The size (120) of input_file exceeds (100) bytes, tools not support.")
|
|
283
|
+
|
|
284
|
+
def test_check_regex_prefix_format_valid(self):
    """An over-long prefix and a prefix containing parentheses both raise ValueError."""
    # One character past the configured limit, so the test follows the constant.
    prefix = "A" * (Const.REGEX_PREFIX_MAX_LENGTH + 1)
    with self.assertRaises(ValueError) as context:
        check_regex_prefix_format_valid(prefix)
    self.assertEqual(str(context.exception), f"Maximum length of prefix is {Const.REGEX_PREFIX_MAX_LENGTH}, "
                                             f"while current length is {len(prefix)}")

    prefix = "(prefix)"
    with self.assertRaises(ValueError) as context:
        check_regex_prefix_format_valid(prefix)
    self.assertEqual(str(context.exception), "prefix contains invalid characters, "
                                             f"prefix pattern {Const.REGEX_PREFIX_PATTERN}")
|
|
296
|
+
|
|
297
|
+
@patch("msprobe.core.common.utils.check_file_or_directory_path")
def test_get_dump_data_path(self, mock_check_file_or_directory_path):
    """An existing directory is validated as a dir and returned with file_is_exist True."""
    test_dir = os.path.dirname(os.path.realpath(__file__))

    dump_data_path, file_is_exist = get_dump_data_path(test_dir)

    checked_args = mock_check_file_or_directory_path.call_args[0]
    self.assertEqual(checked_args, (test_dir, True))
    self.assertEqual(dump_data_path, test_dir)
    self.assertTrue(file_is_exist)
|
|
306
|
+
|
|
307
|
+
@patch.object(logger, "error")
def test_task_dumppath_get(self, mock_error):
    """Cover the invalid-path, tensor, statistics (md5) and overflow_check branches."""
    file_open_target = "msprobe.core.common.utils.FileOpen"
    json_load_target = "msprobe.core.common.utils.json.load"
    input_param = {"npu_json_path": None, "bench_json_path": "bench_json_path"}
    # Payload returned by the patched json.load; its "task" is switched between cases.
    npu_json = {"task": Const.TENSOR, "dump_data_dir": "dump_data_dir", "data": "data"}

    # A None npu_json_path raises INVALID_PATH_ERROR.
    with self.assertRaises(CompareException) as raised:
        task_dumppath_get(input_param)
    self.assertEqual(raised.exception.code, CompareException.INVALID_PATH_ERROR)
    mock_error.assert_called_with("Please check the json path is valid.")

    input_param["npu_json_path"] = "npu_json_path"

    # Tensor task: neither summary nor md5 compare.
    with patch(file_open_target, mock_open(read_data="")), \
            patch(json_load_target, return_value=npu_json):
        summary_compare, md5_compare = task_dumppath_get(input_param)
    self.assertFalse(summary_compare)
    self.assertFalse(md5_compare)

    # Statistics task with md5_find patched to True: md5 compare.
    npu_json["task"] = Const.STATISTICS
    with patch(file_open_target, mock_open(read_data="")), \
            patch(json_load_target, return_value=npu_json), \
            patch("msprobe.core.common.utils.md5_find", return_value=True):
        summary_compare, md5_compare = task_dumppath_get(input_param)
    self.assertFalse(summary_compare)
    self.assertTrue(md5_compare)

    # Overflow-check task: comparison is refused with INVALID_TASK_ERROR.
    npu_json["task"] = Const.OVERFLOW_CHECK
    with patch(file_open_target, mock_open(read_data="")), \
            patch(json_load_target, return_value=npu_json):
        with self.assertRaises(CompareException) as raised:
            task_dumppath_get(input_param)
    self.assertEqual(raised.exception.code, CompareException.INVALID_TASK_ERROR)
    mock_error.assert_called_with("Compare is not required for overflow_check or free_benchmark.")
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
from unittest.mock import patch, mock_open, MagicMock
|
|
3
|
+
|
|
4
|
+
from msprobe.core.common.utils import Const
|
|
5
|
+
from msprobe.core.data_dump.data_collector import DataCollector
|
|
6
|
+
from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
|
|
7
|
+
from msprobe.pytorch.pt_config import parse_json_config
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TestDataCollector(unittest.TestCase):
|
|
11
|
+
def setUp(self):
    """Build a DataCollector from a statistics-task config whose JSON loading is mocked."""
    config_payload = {
        "dump_path": "./ut_dump",
    }
    with patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \
            patch("msprobe.pytorch.pt_config.json.load", return_value=config_payload):
        common_config, task_config = parse_json_config("./config.json", Const.STATISTICS)
    debugger_config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./ut_dump", "L1")
    self.data_collector = DataCollector(debugger_config)
|
|
20
|
+
|
|
21
|
+
def test_update_data(self):
    """For overflow_check the returned msg gains an overflow verdict; for statistics it is unchanged."""
    collector = self.data_collector
    writer_target = "msprobe.core.data_dump.json_writer.DataWriter.update_data"

    collector.config.task = Const.OVERFLOW_CHECK
    # (has_overflow, msg passed in, expected returned msg)
    overflow_cases = (
        (True, "test1:", "test1:Overflow detected."),
        (False, "test2:", "test2:No Overflow, OK."),
    )
    with patch(writer_target, return_value=None):
        for has_overflow, prefix, expected in overflow_cases:
            collector.data_processor.has_overflow = has_overflow
            self.assertEqual(collector.update_data("test message", prefix), expected)

    collector.config.task = Const.STATISTICS
    collector.data_processor.has_overflow = True
    with patch(writer_target, return_value=None):
        self.assertEqual(collector.update_data("test message", "test3"), "test3")
|
|
37
|
+
|
|
38
|
+
def test_pre_forward_data_collect(self):
|
|
39
|
+
self.data_collector.check_scope_and_pid = MagicMock(return_value=False)
|
|
40
|
+
self.data_collector.is_inplace = MagicMock(return_value=False)
|
|
41
|
+
self.data_collector.data_processor.analyze_pre_forward = MagicMock()
|
|
42
|
+
name = "TestModule.forward"
|
|
43
|
+
pid = 123
|
|
44
|
+
|
|
45
|
+
self.data_collector.pre_forward_data_collect(name, None, pid, None)
|
|
46
|
+
self.data_collector.check_scope_and_pid.assert_called_once_with(
|
|
47
|
+
self.data_collector.scope, "TestModule.backward", 123)
|