mindstudio-probe 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
- mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
- mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
- mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
- mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
- mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
- msprobe/README.md +182 -0
- msprobe/__init__.py +0 -0
- msprobe/config/README.md +397 -0
- msprobe/config/config.json +28 -0
- msprobe/config/img/free_benchmark.png +0 -0
- msprobe/core/common/const.py +241 -0
- msprobe/core/common/exceptions.py +88 -0
- msprobe/core/common/file_check.py +265 -0
- msprobe/core/common/log.py +55 -0
- msprobe/core/common/utils.py +516 -0
- msprobe/core/common_config.py +58 -0
- msprobe/core/data_dump/data_collector.py +140 -0
- msprobe/core/data_dump/data_processor/base.py +245 -0
- msprobe/core/data_dump/data_processor/factory.py +61 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
- msprobe/core/data_dump/json_writer.py +116 -0
- msprobe/core/data_dump/scope.py +178 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/debugger/__init__.py +0 -0
- msprobe/mindspore/debugger/debugger_config.py +51 -0
- msprobe/mindspore/debugger/precision_debugger.py +32 -0
- msprobe/mindspore/doc/dump.md +65 -0
- msprobe/mindspore/dump/__init__.py +0 -0
- msprobe/mindspore/dump/api_kbk_dump.py +55 -0
- msprobe/mindspore/dump/dump_tool_factory.py +38 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
- msprobe/mindspore/ms_config.py +78 -0
- msprobe/mindspore/overflow_check/__init__.py +0 -0
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
- msprobe/mindspore/task_handler_factory.py +21 -0
- msprobe/msprobe.py +67 -0
- msprobe/pytorch/__init__.py +4 -0
- msprobe/pytorch/advisor/advisor.py +124 -0
- msprobe/pytorch/advisor/advisor_const.py +59 -0
- msprobe/pytorch/advisor/advisor_result.py +58 -0
- msprobe/pytorch/api_accuracy_checker/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
- msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
- msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
- msprobe/pytorch/common/__init__.py +2 -0
- msprobe/pytorch/common/compare_script.template +14 -0
- msprobe/pytorch/common/log.py +32 -0
- msprobe/pytorch/common/parse_json.py +37 -0
- msprobe/pytorch/common/utils.py +224 -0
- msprobe/pytorch/compare/acc_compare.py +1024 -0
- msprobe/pytorch/compare/distributed_compare.py +111 -0
- msprobe/pytorch/compare/highlight.py +100 -0
- msprobe/pytorch/compare/mapping.yaml +607 -0
- msprobe/pytorch/compare/match.py +36 -0
- msprobe/pytorch/compare/npy_compare.py +244 -0
- msprobe/pytorch/debugger/__init__.py +0 -0
- msprobe/pytorch/debugger/debugger_config.py +86 -0
- msprobe/pytorch/debugger/precision_debugger.py +95 -0
- msprobe/pytorch/doc/FAQ.md +193 -0
- msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
- msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
- msprobe/pytorch/doc/dump.md +207 -0
- msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
- msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
- msprobe/pytorch/doc/img/cpu_info.png +0 -0
- msprobe/pytorch/doc/img/module_compare.png +0 -0
- msprobe/pytorch/doc/parse_tool.md +286 -0
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
- msprobe/pytorch/doc/run_overflow_check.md +25 -0
- msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +90 -0
- msprobe/pytorch/free_benchmark/__init__.py +8 -0
- msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/common/constant.py +67 -0
- msprobe/pytorch/free_benchmark/common/counter.py +72 -0
- msprobe/pytorch/free_benchmark/common/enums.py +37 -0
- msprobe/pytorch/free_benchmark/common/params.py +129 -0
- msprobe/pytorch/free_benchmark/common/utils.py +98 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
- msprobe/pytorch/free_benchmark/main.py +102 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
- msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
- msprobe/pytorch/functional/__init__.py +0 -0
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +39 -0
- msprobe/pytorch/hook_module/__init__.py +1 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -0
- msprobe/pytorch/hook_module/hook_module.py +109 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
- msprobe/pytorch/hook_module/utils.py +29 -0
- msprobe/pytorch/hook_module/wrap_aten.py +100 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
- msprobe/pytorch/hook_module/wrap_functional.py +108 -0
- msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
- msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
- msprobe/pytorch/hook_module/wrap_torch.py +88 -0
- msprobe/pytorch/hook_module/wrap_vf.py +64 -0
- msprobe/pytorch/module_processer.py +98 -0
- msprobe/pytorch/online_dispatch/__init__.py +20 -0
- msprobe/pytorch/online_dispatch/compare.py +236 -0
- msprobe/pytorch/online_dispatch/dispatch.py +274 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
- msprobe/pytorch/online_dispatch/single_compare.py +391 -0
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
- msprobe/pytorch/online_dispatch/utils.py +187 -0
- msprobe/pytorch/parse.py +4 -0
- msprobe/pytorch/parse_tool/__init__.py +0 -0
- msprobe/pytorch/parse_tool/cli.py +32 -0
- msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
- msprobe/pytorch/parse_tool/lib/compare.py +259 -0
- msprobe/pytorch/parse_tool/lib/config.py +51 -0
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
- msprobe/pytorch/parse_tool/lib/utils.py +367 -0
- msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
- msprobe/pytorch/pt_config.py +93 -0
- msprobe/pytorch/service.py +167 -0
- msprobe/test/core_ut/common/test_utils.py +345 -0
- msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
- msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
- msprobe/test/core_ut/data_dump/test_scope.py +151 -0
- msprobe/test/core_ut/test_common_config.py +152 -0
- msprobe/test/core_ut/test_file_check.py +218 -0
- msprobe/test/core_ut/test_log.py +109 -0
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
- msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
- msprobe/test/mindspore_ut/test_ms_config.py +69 -0
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
- msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
- msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
- msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
- msprobe/test/pytorch_ut/test_pt_config.py +69 -0
- msprobe/test/pytorch_ut/test_service.py +59 -0
- msprobe/test/resources/advisor.txt +3 -0
- msprobe/test/resources/compare_result_20230703104808.csv +9 -0
- msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
- msprobe/test/resources/config.yaml +3 -0
- msprobe/test/resources/npu_test.pkl +8 -0
- msprobe/test/run_test.sh +30 -0
- msprobe/test/run_ut.py +58 -0
- msprobe/test/test_module_processer.py +64 -0
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from msprobe.mindspore.ms_config import parse_json_config
|
|
3
|
+
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
4
|
+
from msprobe.mindspore.task_handler_factory import TaskHandlerFactory
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class PrecisionDebugger:
    """Process-wide singleton entry point for MindSpore precision debugging.

    The first construction parses the JSON dump configuration and stores the
    resulting ``DebuggerConfig`` on ``self.config``; every later construction
    returns the same object and leaves the loaded configuration untouched.
    """

    _instance = None

    def __new__(cls, config_path=None):
        # Hand back the existing singleton when there is one; otherwise allocate
        # it and mark it as not yet configured so __init__ does the loading.
        if cls._instance:
            return cls._instance
        fresh = super().__new__(cls)
        fresh.initialized = False
        fresh.config = None
        cls._instance = fresh
        return fresh

    def __init__(self, config_path=None):
        """Load the dump configuration on first construction only.

        Args:
            config_path: path of the JSON configuration file. When falsy, the
                file at ``../../config/config.json`` relative to this module's
                directory is used.
        """
        if self.initialized:
            # Already configured: a config_path passed to a later call is ignored.
            return
        chosen_path = config_path or os.path.join(os.path.dirname(__file__), "../../config/config.json")
        common_cfg, task_cfg = parse_json_config(chosen_path)
        self.config = DebuggerConfig(common_cfg, task_cfg)
        self.initialized = True

    @classmethod
    def start(cls, target=None):
        """Build the task handler for the loaded configuration and run it.

        Args:
            target: accepted for interface compatibility; not used here.

        Raises:
            Exception: if no ``PrecisionDebugger`` has been constructed yet.
        """
        debugger = cls._instance
        if not debugger:
            raise Exception("No instance of PrecisionDebugger found.")
        TaskHandlerFactory.create(debugger.config).handle()
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# **精度数据采集**
|
|
2
|
+
|
|
3
|
+
msprobe工具主要通过在训练脚本内添加dump接口并启动训练的方式来采集精度数据。
|
|
4
|
+
|
|
5
|
+
执行dump操作需要安装msprobe工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。
|
|
6
|
+
|
|
7
|
+
## dump接口介绍
|
|
8
|
+
|
|
9
|
+
### PrecisionDebugger
|
|
10
|
+
|
|
11
|
+
**功能说明**
|
|
12
|
+
|
|
13
|
+
通过加载dump配置文件的方式来确定dump操作的详细配置。
|
|
14
|
+
|
|
15
|
+
可以在`from msprobe.mindspore import PrecisionDebugger`语句和模型初始化之间的任意位置添加该接口。
|
|
16
|
+
|
|
17
|
+
**原型**
|
|
18
|
+
|
|
19
|
+
```Python
|
|
20
|
+
PrecisionDebugger(config_path=None)
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
**参数说明**
|
|
24
|
+
|
|
25
|
+
| 参数名 | 说明 | 是否必选 |
|
|
26
|
+
| ----------- | ------------------------------------------------------------ | -------- |
|
|
27
|
+
| config_path | 指定dump配置文件路径,String类型。参数示例:"./config.json"。未配置该路径时,默认使用[config.json](../../config/config.json)文件的默认配置。config.json文件可以配置更多参数,若需要进行更多场景的精度数据dump,建议配置[config.json](../../config/config.json)文件。 | 否 |
|
|
28
|
+
|
|
29
|
+
### start函数
|
|
30
|
+
|
|
31
|
+
**功能说明**
|
|
32
|
+
|
|
33
|
+
启动函数。
|
|
34
|
+
|
|
35
|
+
**原型**
|
|
36
|
+
|
|
37
|
+
```Python
|
|
38
|
+
debugger.start()
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
该函数为类方法(classmethod),既可以使用debugger.start()调用,也可以使用PrecisionDebugger.start()调用。调用前须先创建PrecisionDebugger实例,否则会抛出"No instance of PrecisionDebugger found."异常。
|
|
42
|
+
|
|
43
|
+
## 示例代码
|
|
44
|
+
|
|
45
|
+
```Python
|
|
46
|
+
from msprobe.mindspore import PrecisionDebugger
|
|
47
|
+
debugger = PrecisionDebugger(config_path="./config.json")
|
|
48
|
+
# 请勿将以上初始化流程插入到循环代码中
|
|
49
|
+
# 下面代码也可以用PrecisionDebugger.start()
|
|
50
|
+
debugger.start()
|
|
51
|
+
...
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
## dump结果文件介绍
|
|
55
|
+
|
|
56
|
+
训练结束后,工具将dump的数据保存在dump_path参数指定的目录下。
|
|
57
|
+
|
|
58
|
+
- level为L1时
|
|
59
|
+
|
|
60
|
+
dump结果目录请参见MindSpore官网中的《[同步Dump数据对象目录](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.0rc2/debug/dump.html#%E5%90%8C%E6%AD%A5dump%E6%95%B0%E6%8D%AE%E5%AF%B9%E8%B1%A1%E7%9B%AE%E5%BD%95)》。
|
|
61
|
+
|
|
62
|
+
- level为L2时
|
|
63
|
+
|
|
64
|
+
dump结果目录请参见MindSpore官网中的《[异步Dump数据对象目录](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.0rc2/debug/dump.html#%E5%BC%82%E6%AD%A5dump%E6%95%B0%E6%8D%AE%E5%AF%B9%E8%B1%A1%E7%9B%AE%E5%BD%95)》。
|
|
65
|
+
|
|
File without changes
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
from msprobe.core.common.utils import make_dump_path_if_not_exists
|
|
4
|
+
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
5
|
+
from msprobe.core.common.log import logger
|
|
6
|
+
from msprobe.core.common.file_check import FileOpen
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ApiKbkDump:
    """Build and install a MindSpore dump config for API-level, kernel-by-kernel (kbk) dumping.

    ``__init__`` translates a ``DebuggerConfig`` into the dictionary written to
    ``api_kbk_dump.json``; ``handle`` writes that file under the dump path and
    exports the environment variables that point MindSpore at it.
    """

    def __init__(self, config: DebuggerConfig):
        """Translate ``config`` into ``self.dump_json``.

        Reads ``config.dump_path``, ``list``, ``step``, ``rank``, ``task`` and
        ``data_mode``. Empty (or ``None``) list/step/rank/data_mode values keep
        the defaults set below.
        """
        settings = {
            "dump_mode": 0,
            "path": config.dump_path,
            "net_name": "Net",
            "iteration": "all",
            "saved_data": "statistic",
            "input_output": 0,
            "kernels": [],
            "support_device": [0, 1, 2, 3, 4, 5, 6, 7],
        }
        if config.list:
            # A non-empty kernel list switches to dump_mode 1 restricted to that list.
            settings["dump_mode"] = 1
            settings["kernels"] = config.list
        if config.step:
            # Steps are serialized as one "|"-separated string, e.g. [0, 2] -> "0|2".
            settings["iteration"] = "|".join(str(step) for step in config.step)
        if config.rank:
            settings["support_device"] = config.rank
        if config.task == "tensor":
            settings["saved_data"] = "tensor"
        data_mode = config.data_mode or []
        if len(data_mode) == 1:
            # "input" -> 1, "output" -> 2; any other value (e.g. "all") keeps 0.
            if data_mode[0] == "input":
                settings["input_output"] = 1
            elif data_mode[0] == "output":
                settings["input_output"] = 2
        self.dump_json = {
            "common_dump_settings": settings,
            "e2e_dump_settings": {"enable": True, "trans_flag": True},
        }

    def handle(self):
        """Write ``<dump_path>/api_kbk_dump.json`` and activate it.

        Creates the dump directory if needed, sets ``GRAPH_OP_RUN=1`` and
        ``MINDSPORE_DUMP_CONFIG`` to the written file, and removes
        ``MS_ACL_DUMP_CFG_PATH`` (which the kernel-graph dump handler may set).
        """
        dump_dir = self.dump_json["common_dump_settings"]["path"]
        make_dump_path_if_not_exists(dump_dir)
        json_path = os.path.join(dump_dir, "api_kbk_dump.json")
        with FileOpen(json_path, 'w') as f:
            json.dump(self.dump_json, f)
        logger.info(json_path + " has been created.")
        os.environ["GRAPH_OP_RUN"] = "1"
        os.environ["MINDSPORE_DUMP_CONFIG"] = json_path
        # pop with a default: no KeyError when the variable was never set.
        os.environ.pop("MS_ACL_DUMP_CFG_PATH", None)
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
2
|
+
from msprobe.mindspore.dump.api_kbk_dump import ApiKbkDump
|
|
3
|
+
from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DumpToolFactory:
    """Select and instantiate the dump handler matching a config's ``level``.

    ``tools`` maps level -> execution mode -> handler class, where ``None``
    marks an unsupported combination. Only ``api``/``kbk`` and
    ``kernel``/``graph`` have handlers.
    """

    tools = {
        "cell": {"kbk": None, "graph": None, "pynative": None},
        "api": {"kbk": ApiKbkDump, "graph": None, "pynative": None},
        "kernel": {"kbk": None, "graph": KernelGraphDump, "pynative": None},
    }

    @staticmethod
    def create(config: DebuggerConfig):
        """Return a dump handler instance built from ``config``.

        Raises:
            Exception: for an unknown level, for the ``cell`` level, or when
                the chosen level/mode pair has no handler.
        """
        level = config.level
        modes = DumpToolFactory.tools.get(level)
        if not modes:
            raise Exception("valid level is needed.")
        # NOTE(review): "in not" below is a typo for "is not"; the text is kept
        # byte-identical because callers/tests may match the exact message.
        if level == "cell":
            raise Exception("Cell dump in not supported now.")
        # api runs kernel-by-kernel, kernel runs in graph mode.
        tool = modes
        if level == "api":
            tool = modes.get("kbk")
        elif level == "kernel":
            tool = modes.get("graph")
        if not tool:
            raise Exception("Data dump in not supported in this mode.")
        return tool(config)
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
from msprobe.core.common.utils import make_dump_path_if_not_exists
|
|
4
|
+
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
5
|
+
from msprobe.core.common.log import logger
|
|
6
|
+
from msprobe.core.common.file_check import FileOpen
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class KernelGraphDump:
    """Build and install a MindSpore dump config for kernel-level dumping in graph mode.

    ``__init__`` translates a ``DebuggerConfig`` into the dictionary written to
    ``kernel_graph_dump.json``; ``handle`` writes that file under the dump path
    and exports the environment variables that point MindSpore at it.
    """

    def __init__(self, config: DebuggerConfig):
        """Translate ``config`` into ``self.dump_json``.

        Reads ``config.dump_path``, ``list``, ``step``, ``rank``, ``task``,
        ``file_format`` and ``data_mode``. Empty (or ``None``)
        list/step/rank/data_mode values keep the defaults set below.
        """
        settings = {
            "dump_mode": 0,
            "path": config.dump_path,
            "net_name": "Net",
            "iteration": "all",
            "saved_data": "statistic",
            "input_output": 0,
            "kernels": [],
            "support_device": [0, 1, 2, 3, 4, 5, 6, 7],
            "op_debug_mode": 0,
            "file_format": "npy",
        }
        if config.list:
            # A non-empty kernel list switches to dump_mode 1 restricted to that list.
            settings["dump_mode"] = 1
            settings["kernels"] = config.list
        if config.step:
            # Steps are serialized as one "|"-separated string, e.g. [0, 2] -> "0|2".
            settings["iteration"] = "|".join(str(step) for step in config.step)
        if config.rank:
            settings["support_device"] = config.rank
        if config.task == "tensor":
            settings["saved_data"] = "tensor"
            # Keep the "npy" default when no file format was configured, instead
            # of writing null into the MindSpore dump config.
            settings["file_format"] = config.file_format or "npy"
        data_mode = config.data_mode or []
        if len(data_mode) == 1:
            # "input" -> 1, "output" -> 2; any other value (e.g. "all") keeps 0.
            if data_mode[0] == "input":
                settings["input_output"] = 1
            elif data_mode[0] == "output":
                settings["input_output"] = 2
        self.dump_json = {"common_dump_settings": settings}

    def handle(self):
        """Write ``<dump_path>/kernel_graph_dump.json`` and activate it.

        Sets ``MINDSPORE_DUMP_CONFIG`` to the written file. With dump_mode 0 it
        also exports ``MS_ACL_DUMP_CFG_PATH``; otherwise that variable is removed.

        Raises:
            Exception: when ``GRAPH_OP_RUN`` is "1" (kbk mode is active).
        """
        if os.getenv("GRAPH_OP_RUN") == "1":
            raise Exception("Must run in graph mode, not kbk mode")
        settings = self.dump_json["common_dump_settings"]
        dump_dir = settings["path"]
        make_dump_path_if_not_exists(dump_dir)
        json_path = os.path.join(dump_dir, "kernel_graph_dump.json")
        with FileOpen(json_path, 'w') as f:
            json.dump(self.dump_json, f)
        logger.info(json_path + " has been created.")
        os.environ["MINDSPORE_DUMP_CONFIG"] = json_path
        if settings["dump_mode"] == 0:
            # NOTE: __init__ only leaves dump_mode at 0 when the kernel list is
            # empty, so in practice this condition is always true.
            if settings["iteration"] != "all" or len(settings["kernels"]) == 0:
                os.environ["MS_ACL_DUMP_CFG_PATH"] = json_path
        else:
            os.environ.pop("MS_ACL_DUMP_CFG_PATH", None)
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from msprobe.core.common_config import CommonConfig, BaseConfig
|
|
3
|
+
from msprobe.core.common.file_check import FileOpen
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TensorConfig(BaseConfig):
    """Task configuration for the ``tensor`` task.

    Adds ``file_format`` validation (``npy`` or ``bin``) on top of the shared
    ``BaseConfig`` checks; ``check_mode`` is forced to ``None``.
    """

    def __init__(self, json_config):
        super().__init__(json_config)
        self.check_mode = None
        self.file_format = json_config.get("file_format")
        # Base-class validation first, then the task-specific rules.
        self.check_config()
        self._check_config()

    def _check_config(self):
        """Raise ``Exception`` for an invalid ``data_mode`` or ``file_format``."""
        modes = self.data_mode
        has_modes = modes is not None and len(modes) > 0
        # Exactly one entry is allowed, and it must be all/input/output.
        if has_modes and (len(modes) > 1 or modes[0] not in ("all", "input", "output")):
            raise Exception("data_mode must be all, input or output")
        fmt = self.file_format
        if fmt and fmt not in ("npy", "bin"):
            raise Exception("file_format is invalid")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class StatisticsConfig(BaseConfig):
    """Task configuration for the ``statistics`` task.

    ``file_format`` and ``check_mode`` are forced to ``None``; only
    ``data_mode`` gets task-specific validation.
    """

    def __init__(self, json_config):
        super().__init__(json_config)
        self.file_format = None
        self.check_mode = None
        # Base-class validation first, then the task-specific rule.
        self.check_config()
        self._check_config()

    def _check_config(self):
        """Raise ``Exception`` unless ``data_mode`` is unset/empty or a single all/input/output entry."""
        modes = self.data_mode
        if modes is None or len(modes) == 0:
            return
        if len(modes) > 1 or modes[0] not in ("all", "input", "output"):
            raise Exception("data_mode must be all, input or output")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class OverflowCheck(BaseConfig):
    """Task configuration for the ``overflow_check`` task.

    Validates ``data_mode`` and ``check_mode`` (``all``, ``aicore`` or
    ``atomic``); ``file_format`` is forced to ``None``.
    """

    def __init__(self, json_config):
        super().__init__(json_config)
        self.file_format = None
        self.check_mode = json_config.get("check_mode")
        # NOTE(review): unlike TensorConfig and StatisticsConfig, this class does
        # not call BaseConfig.check_config(); confirm whether that is intentional.
        self._check_config()

    def _check_config(self):
        """Raise ``Exception`` for an invalid ``data_mode`` or ``check_mode``."""
        modes = self.data_mode
        if modes is not None and len(modes) > 0:
            single_valid = len(modes) == 1 and modes[0] in ("all", "input", "output")
            if not single_valid:
                raise Exception("data_mode must be all, input or output")
        mode = self.check_mode
        if mode and mode not in ("all", "aicore", "atomic"):
            raise Exception("check_mode is invalid")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def parse_common_config(json_config):
    """Wrap the whole parsed JSON mapping in a ``CommonConfig``."""
    common = CommonConfig(json_config)
    return common
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def parse_task_config(task, json_config):
    """Build the task-specific configuration object for ``task``.

    Args:
        task: ``"tensor"``, ``"statistics"`` or ``"overflow_check"``.
        json_config: the whole parsed configuration mapping; the section keyed
            by ``task`` holds that task's options.

    Returns:
        A ``TensorConfig``, ``StatisticsConfig`` or ``OverflowCheck`` instance.

    Raises:
        Exception: if ``task`` is not one of the supported names.
    """
    # A missing, null or empty section falls back to defaults. Previously a
    # missing key raised KeyError before this fallback could apply.
    task_map = json_config.get(task) or {}
    if task == "tensor":
        return TensorConfig(task_map)
    if task == "statistics":
        return StatisticsConfig(task_map)
    if task == "overflow_check":
        return OverflowCheck(task_map)
    raise Exception("task is invalid.")
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def parse_json_config(json_file_path):
    """Load a JSON configuration file and split it into common and task configs.

    Args:
        json_file_path: path of the JSON file to read.

    Returns:
        ``(common_config, task_config)``. ``common_config.task`` defaults to
        ``"statistics"`` when the file leaves it empty.

    Raises:
        Exception: if ``json_file_path`` is empty or ``None``.
    """
    if not json_file_path:
        raise Exception("json file path is None")
    with FileOpen(json_file_path, 'r') as stream:
        raw = json.load(stream)
    common = parse_common_config(raw)
    if not common.task:
        common.task = "statistics"
    return common, parse_task_config(common.task, raw)
|
|
File without changes
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
from msprobe.core.common.utils import make_dump_path_if_not_exists
|
|
4
|
+
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
5
|
+
from msprobe.core.common.log import logger
|
|
6
|
+
from msprobe.core.common.file_check import FileOpen
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class KernelGraphOverflowCheck:
    """Build and install a MindSpore dump config for overflow checking in graph mode.

    ``__init__`` translates a ``DebuggerConfig`` into the dictionary written to
    ``kernel_graph_overflow_check.json``; ``handle`` writes it and exports
    ``MINDSPORE_DUMP_CONFIG``.
    """

    def __init__(self, config: DebuggerConfig):
        """Translate ``config`` into ``self.dump_json``.

        Reads ``config.dump_path``, ``step``, ``rank`` and ``check_mode``. A
        non-empty step list only produces a warning; iteration stays "all".
        """
        settings = dict(
            dump_mode=0,
            path=config.dump_path,
            net_name="Net",
            iteration="all",
            saved_data="full",
            input_output=0,
            kernels=[],
            support_device=[0, 1, 2, 3, 4, 5, 6, 7],
            op_debug_mode=3,
            file_format="npy",
        )
        self.dump_json = {"common_dump_settings": settings}
        if len(config.step):
            logger.warning("Step would change to all in this task.")
        if len(config.rank):
            settings["support_device"] = config.rank
        # check_mode -> op_debug_mode: aicore 1, atomic 2, anything else keeps 3.
        mode = config.check_mode
        settings["op_debug_mode"] = 1 if mode == "aicore" else 2 if mode == "atomic" else 3

    def handle(self):
        """Write ``<dump_path>/kernel_graph_overflow_check.json`` and activate it.

        Sets ``MINDSPORE_DUMP_CONFIG`` to the written file and removes
        ``MS_ACL_DUMP_CFG_PATH`` if present.

        Raises:
            Exception: when ``GRAPH_OP_RUN`` is "1" (kbk mode is active).
        """
        if os.getenv("GRAPH_OP_RUN") == "1":
            raise Exception("Must run in graph mode, not kbk mode")
        out_dir = self.dump_json["common_dump_settings"]["path"]
        make_dump_path_if_not_exists(out_dir)
        cfg_file = os.path.join(out_dir, "kernel_graph_overflow_check.json")
        with FileOpen(cfg_file, 'w') as f:
            json.dump(self.dump_json, f)
        logger.info(cfg_file + " has been created.")
        os.environ["MINDSPORE_DUMP_CONFIG"] = cfg_file
        os.environ.pop("MS_ACL_DUMP_CFG_PATH", None)
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
2
|
+
from msprobe.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class OverflowCheckToolFactory:
    """Select and instantiate the overflow-check handler for a config's ``level``.

    ``tools`` maps level -> execution mode -> handler class (``None`` means
    unsupported). ``create`` always looks up the ``graph`` mode; only
    ``kernel``/``graph`` has a handler.
    """

    tools = {
        "cell": {"kbk": None, "graph": None, "pynative": None},
        "api": {"kbk": None, "graph": None, "pynative": None},
        "kernel": {"kbk": None, "graph": KernelGraphOverflowCheck, "pynative": None},
    }

    @staticmethod
    def create(config: DebuggerConfig):
        """Return an overflow-check handler built from ``config``.

        Raises:
            Exception: for an unknown level, or when the level has no graph-mode handler.
        """
        modes = OverflowCheckToolFactory.tools.get(config.level)
        if not modes:
            raise Exception("valid level is needed.")
        handler_cls = modes.get("graph")
        # NOTE(review): "in not" is a typo for "is not"; kept byte-identical.
        if not handler_cls:
            raise Exception("Overflow check in not supported in this mode.")
        return handler_cls(config)
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
2
|
+
from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory
|
|
3
|
+
from msprobe.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TaskHandlerFactory:
    """Dispatch a config's ``task`` to the tool factory that builds its handler.

    ``tensor`` and ``statistics`` go through ``DumpToolFactory``;
    ``overflow_check`` goes through ``OverflowCheckToolFactory``.
    """

    tasks = {
        "tensor": DumpToolFactory,
        "statistics": DumpToolFactory,
        "overflow_check": OverflowCheckToolFactory,
    }

    @staticmethod
    def create(config: DebuggerConfig):
        """Return the handler for ``config.task``.

        Raises:
            Exception: for an unknown task, or if the factory returns a falsy handler.
        """
        factory = TaskHandlerFactory.tasks.get(config.task)
        if not factory:
            raise Exception("valid task is needed.")
        handler = factory.create(config)
        if handler:
            return handler
        raise Exception("Can not find task handler")
|
msprobe/msprobe.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# Copyright (c) 2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import argparse
|
|
17
|
+
import sys
|
|
18
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
|
|
19
|
+
from msprobe.pytorch.parse_tool.cli import parse as cli_parse
|
|
20
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut
|
|
21
|
+
from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, \
|
|
22
|
+
_api_precision_compare_command
|
|
23
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
|
|
24
|
+
_run_overflow_check_command
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def main():
    """
    Entry point of the ``msprobe`` command line tool.

    Builds the argument parser (a mandatory ``-f/--framework`` option plus one
    sub-command per tool), parses ``sys.argv`` and dispatches to the selected
    sub-command. With no arguments at all the help text is printed and the
    process exits with status 0. If no sub-command is given, the help text is
    printed instead of running a tool.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="msprobe(mindstudio probe), [Powered by MindStudio].\n"
                    "Providing one-site accuracy difference debugging toolkit for training on Ascend Devices.\n"
                    "For any issue, refer README.md first",
    )
    parser.set_defaults(print_help=parser.print_help)
    parser.add_argument('-f', '--framework', required=True, choices=['pytorch'],
                        help='Deep learning framework.')
    # dest="command" records which sub-command argparse selected. Dispatch then
    # does not depend on the sub-command's position in sys.argv: the old
    # sys.argv[3] lookup broke for "--framework=pytorch run_ut" and raised
    # IndexError when no sub-command was given.
    subparsers = parser.add_subparsers(dest="command")
    subparsers.add_parser('parse')
    run_ut_cmd_parser = subparsers.add_parser('run_ut')
    multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut')
    api_precision_compare_cmd_parser = subparsers.add_parser('api_precision_compare')
    run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check')
    _run_ut_parser(run_ut_cmd_parser)
    _run_ut_parser(multi_run_ut_cmd_parser)
    multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
                                         help='Number of splits for parallel processing. Range: 1-64')
    _api_precision_compare_parser(api_precision_compare_cmd_parser)
    _run_overflow_check_parser(run_overflow_check_cmd_parser)
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)
    args = parser.parse_args(sys.argv[1:])
    command = args.command
    if command == "run_ut":
        run_ut_command(args)
    elif command == "parse":
        cli_parse()
    elif command == "multi_run_ut":
        config = prepare_config(args)
        run_parallel_ut(config)
    elif command == "api_precision_compare":
        _api_precision_compare_command(args)
    elif command == "run_overflow_check":
        _run_overflow_check_command(args)
    elif command is None:
        # Only global options (e.g. "-f pytorch") were given: show usage
        # instead of crashing.
        parser.print_help()


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
import os
|
|
19
|
+
|
|
20
|
+
from msprobe.pytorch.advisor.advisor_result import AdvisorResult
|
|
21
|
+
from msprobe.pytorch.advisor.advisor_const import AdvisorConst
|
|
22
|
+
from msprobe.pytorch.common.log import logger
|
|
23
|
+
from msprobe.core.common.utils import CompareException
|
|
24
|
+
from msprobe.core.common.file_check import FileChecker
|
|
25
|
+
from msprobe.core.common.const import Const, CompareConst, FileCheckConst
|
|
26
|
+
|
|
27
|
+
class Advisor:
    """
    Produce accuracy-debugging advice from a comparison result table.

    ``input_data`` is accessed through ``.columns``, ``.iloc`` and
    ``.reset_index``, so it is presumably a pandas DataFrame produced by the
    compare step. Its columns decide whether it is handled as a full
    (``Const.ALL``), MD5 (``Const.MD5``) or summary (``Const.SUMMARY``) result.
    """

    def __init__(self, input_data, out_path=""):
        # Raw comparison result to analyse.
        self.input_data = input_data
        # Directory for the advisor summary file; realpath("") is the current directory.
        self.out_path = os.path.realpath(out_path)
        # Filled in by _parse_input_data() once the column layout is recognised.
        self.file_type = None

    @staticmethod
    def deterministic_advisor(message, node_name):
        """
        Return AdvisorConst.DETERMINISTIC_SUGGEST when any API name from
        AdvisorConst.NEED_DETERMINISTIC_API occurs as a substring of
        ``node_name``; otherwise return ``message`` unchanged.
        """
        needs_deterministic = any(
            api_name in node_name for api_name in AdvisorConst.NEED_DETERMINISTIC_API
        )
        return AdvisorConst.DETERMINISTIC_SUGGEST if needs_deterministic else message

    @staticmethod
    def batch_norm_advisor(message, node_name):
        """
        Return AdvisorConst.BATCH_NORM_SUGGEST when ``node_name`` contains both
        AdvisorConst.FUNC_BATCH_NORM and AdvisorConst.FORWARD_INPUT_1;
        otherwise return ``message`` unchanged.
        """
        is_batch_norm_input = (AdvisorConst.FUNC_BATCH_NORM in node_name
                               and AdvisorConst.FORWARD_INPUT_1 in node_name)
        return AdvisorConst.BATCH_NORM_SUGGEST if is_batch_norm_input else message

    def analyze_unmatched(self, analyze_data):
        """
        Log one warning per row whose NPU and bench entries could not be paired.

        For ``Const.ALL`` results these are rows whose accuracy column equals
        ``CompareConst.ACCURACY_CHECK_UNMATCH``. For the other result types
        they are rows where either shape column equals ``CompareConst.NAN``.
        """
        if self.file_type == Const.ALL:
            unmatched_mask = analyze_data[CompareConst.ACCURACY] == CompareConst.ACCURACY_CHECK_UNMATCH
        else:
            npu_shape_missing = analyze_data[CompareConst.NPU_SHAPE] == CompareConst.NAN
            bench_shape_missing = analyze_data[CompareConst.BENCH_SHAPE] == CompareConst.NAN
            unmatched_mask = npu_shape_missing | bench_shape_missing
        for npu_name in analyze_data[unmatched_mask][CompareConst.NPU_NAME]:
            logger.warning("The tensor name matches but the shape or dtype does not match: {}"
                           .format(npu_name))

    def gen_advisor_result(self, pd_data):
        """
        Build an AdvisorResult for the first row of ``pd_data``.

        ``pd_data`` must be non-empty and carry the ``'index'`` column added by
        ``_parse_input_data``, which holds the row's label in the original table.
        """
        head = pd_data.iloc[0]
        node_name = head[CompareConst.NPU_NAME]
        line_index = head['index']
        message = self.gen_advisor_message(node_name)
        logger.warning("Find %s accuracy not reached, the line is %s" % (node_name, line_index))
        return AdvisorResult(node_name, line_index, message)

    def gen_advisor_message(self, node_name):
        """
        Choose the suggestion text for ``node_name``.

        The base suggestion depends on two tests:
        - whether the name contains AdvisorConst.FORWARD (forward vs. backward);
        - whether it contains AdvisorConst.INPUT (input vs. output).
        The deterministic check may then override it, followed by the
        batch-norm check.
        """
        is_input = AdvisorConst.INPUT in node_name
        if AdvisorConst.FORWARD in node_name:
            message = (AdvisorConst.FORWARD_INPUT_SUGGEST if is_input
                       else AdvisorConst.FORWARD_OUTPUT_SUGGEST)
        else:
            message = (AdvisorConst.BACKWARD_INPUT_SUGGEST if is_input
                       else AdvisorConst.BACKWARD_OUTPUT_SUGGEST)
        message = self.deterministic_advisor(message, node_name)
        return self.batch_norm_advisor(message, node_name)

    def analysis(self):
        """
        Run the full advisor pipeline.

        Steps:
        1. Check that ``out_path`` is a writable directory.
        2. Classify the input table.
        3. Warn about unmatched rows.
        4. Select the failing rows for the detected result type.
        5. Build an AdvisorResult: for the first failing row, or a
           "no error" result when nothing failed.
        6. Print the result and write its summary file into ``out_path``.
        """
        self._check_path_vaild()
        analyze_data = self._parse_input_data()
        logger.info("Start analyzing the comparison result: %s" % self.file_type)
        self.analyze_unmatched(analyze_data)
        # Each result type marks failures in a different column/value pair.
        if self.file_type == Const.ALL:
            status_column, failing_value = CompareConst.ACCURACY, CompareConst.ACCURACY_CHECK_NO
        elif self.file_type == Const.MD5:
            status_column, failing_value = CompareConst.RESULT, CompareConst.DIFF
        elif self.file_type == Const.SUMMARY:
            status_column, failing_value = CompareConst.RESULT, CompareConst.WARNING
        failing_data = analyze_data[analyze_data[status_column] == failing_value]
        if not failing_data.empty:
            result = self.gen_advisor_result(failing_data)
        else:
            logger.info("All data from api input/output accuracy reached")
            result = AdvisorResult(AdvisorConst.NO_ERROR_API, AdvisorConst.NO_ERROR_API,
                                   AdvisorConst.NO_ERR_SUGGEST)
        message_list = result.print_advisor_log()
        result.gen_summary_file(self.out_path, message_list)

    def _parse_input_data(self):
        """
        Detect the result type from the column names and set ``self.file_type``.

        :return: ``input_data.reset_index()``. The original row labels become
            a column, which gen_advisor_result reads back as ``'index'``.
        :raises CompareException: if the columns match no known result layout.
        """
        present_columns = set(self.input_data.columns.values)
        if {CompareConst.ACCURACY, CompareConst.NPU_NAME} <= present_columns:
            self.file_type = Const.ALL
        elif {CompareConst.RESULT, CompareConst.NPU_MD5} <= present_columns:
            self.file_type = Const.MD5
        elif {CompareConst.MAX_DIFF, CompareConst.RESULT} <= present_columns:
            self.file_type = Const.SUMMARY
        else:
            logger.error('Compare result does not meet the required conditions.')
            raise CompareException(CompareException.INVALID_DATA_ERROR)
        return self.input_data.reset_index()

    def _check_path_vaild(self):
        """Verify via FileChecker that ``out_path`` is a writable directory."""
        # The misspelled method name is kept because analysis() calls it by this name.
        FileChecker(self.out_path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE).common_check()
|