mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
- mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +101 -237
- msprobe/{config/config.json → config.json} +49 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +59 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +341 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +478 -283
- msprobe/core/common/log.py +76 -69
- msprobe/core/common/utils.py +385 -616
- msprobe/core/common_config.py +85 -71
- msprobe/core/compare/acc_compare.py +299 -298
- msprobe/core/compare/check.py +95 -95
- msprobe/core/compare/compare_cli.py +49 -49
- msprobe/core/compare/highlight.py +223 -222
- msprobe/core/compare/multiprocessing_compute.py +149 -149
- msprobe/core/compare/npy_compare.py +295 -295
- msprobe/core/compare/utils.py +430 -429
- msprobe/core/data_dump/data_collector.py +154 -144
- msprobe/core/data_dump/data_processor/base.py +314 -293
- msprobe/core/data_dump/data_processor/factory.py +59 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
- msprobe/core/data_dump/json_writer.py +96 -116
- msprobe/core/data_dump/scope.py +178 -178
- msprobe/core/grad_probe/constant.py +70 -70
- msprobe/core/grad_probe/grad_compare.py +171 -175
- msprobe/core/grad_probe/utils.py +64 -52
- msprobe/docs/01.installation.md +89 -0
- msprobe/docs/02.config_introduction.md +165 -0
- msprobe/docs/03.config_examples.md +247 -0
- msprobe/docs/04.acl_config_examples.md +76 -0
- msprobe/docs/05.data_dump_PyTorch.md +198 -0
- msprobe/docs/06.data_dump_MindSpore.md +243 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
- msprobe/docs/FAQ_PyTorch.md +177 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +34 -34
- msprobe/mindspore/common/const.py +106 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +81 -57
- msprobe/mindspore/compare/distributed_compare.py +75 -75
- msprobe/mindspore/compare/ms_compare.py +219 -117
- msprobe/mindspore/compare/ms_graph_compare.py +348 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +66 -74
- msprobe/mindspore/debugger/precision_debugger.py +126 -107
- msprobe/mindspore/dump/dump_tool_factory.py +35 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
- msprobe/mindspore/dump/jit_dump.py +72 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
- msprobe/mindspore/free_benchmark/common/config.py +12 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
- msprobe/mindspore/free_benchmark/common/utils.py +71 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
- msprobe/mindspore/grad_probe/global_context.py +90 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +378 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +3 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
- msprobe/pytorch/bench_functions/__init__.py +15 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
- msprobe/pytorch/bench_functions/linear.py +12 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
- msprobe/pytorch/bench_functions/rms_norm.py +15 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
- msprobe/pytorch/bench_functions/swiglu.py +55 -55
- msprobe/pytorch/common/__init__.py +2 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +20 -31
- msprobe/pytorch/common/parse_json.py +39 -39
- msprobe/pytorch/common/utils.py +305 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +34 -33
- msprobe/pytorch/compare/pt_compare.py +50 -40
- msprobe/pytorch/debugger/debugger_config.py +95 -95
- msprobe/pytorch/debugger/precision_debugger.py +125 -125
- msprobe/pytorch/free_benchmark/__init__.py +8 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +37 -37
- msprobe/pytorch/free_benchmark/common/params.py +129 -129
- msprobe/pytorch/free_benchmark/common/utils.py +102 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
- msprobe/pytorch/free_benchmark/main.py +105 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
- msprobe/pytorch/function_factory.py +76 -75
- msprobe/pytorch/functional/dump_module.py +39 -39
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/api_registry.py +161 -161
- msprobe/pytorch/hook_module/hook_module.py +120 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +30 -29
- msprobe/pytorch/hook_module/wrap_aten.py +110 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
- msprobe/pytorch/hook_module/wrap_functional.py +105 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
- msprobe/pytorch/hook_module/wrap_torch.py +86 -86
- msprobe/pytorch/hook_module/wrap_vf.py +62 -62
- msprobe/pytorch/module_processer.py +138 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +236 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
- msprobe/pytorch/online_dispatch/utils.py +130 -146
- msprobe/pytorch/parse.py +4 -4
- msprobe/pytorch/parse_tool/cli.py +32 -32
- msprobe/pytorch/parse_tool/lib/compare.py +260 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
- msprobe/pytorch/parse_tool/lib/utils.py +316 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +188 -187
- msprobe/pytorch/service.py +246 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc/在线精度比对.md +0 -90
- msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
|
@@ -1,293 +1,314 @@
|
|
|
1
|
-
import os
|
|
2
|
-
import inspect
|
|
3
|
-
from dataclasses import dataclass
|
|
4
|
-
from typing import Tuple, Dict, Optional, Any
|
|
5
|
-
import numpy as np
|
|
6
|
-
from msprobe.core.common.log import logger
|
|
7
|
-
from msprobe.core.common.utils import convert_tuple
|
|
8
|
-
from msprobe.core.common.const import Const
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
@dataclass
class ModuleForwardInputsOutputs:
    """Positional args, keyword args and output captured from one forward call.

    Attributes:
        args: positional arguments of the call; may be None.
        kwargs: keyword arguments of the call; may be None.
        output: whatever the call returned.
    """
    args: Optional[Tuple]
    kwargs: Optional[Dict]
    output: Any

    @property
    def args_tuple(self):
        """``args`` passed through ``convert_tuple``."""
        return convert_tuple(self.args)

    @property
    def output_tuple(self):
        """``output`` passed through ``convert_tuple``."""
        return convert_tuple(self.output)

    def concat_args_and_kwargs(self):
        """Return the positional args followed by the keyword-argument values as one tuple.

        Both fields are declared Optional, so a missing (None) ``args`` or ``kwargs`` is
        treated as empty instead of raising TypeError.
        """
        args = tuple(self.args) if self.args is not None else ()
        kwargs = self.kwargs if self.kwargs is not None else {}
        return args + tuple(kwargs.values())
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
@dataclass
class ModuleBackwardInputsOutputs:
    """Holds both the ``grad_output`` and the ``grad_input`` of a backward call.

    The field order (grad_output, then grad_input) defines the positional constructor.
    """
    grad_output: Optional[Tuple]
    grad_input: Optional[Tuple]

    @property
    def grad_output_tuple(self):
        """``grad_output`` normalised with ``convert_tuple``."""
        value = self.grad_output
        return convert_tuple(value)

    @property
    def grad_input_tuple(self):
        """``grad_input`` normalised with ``convert_tuple``."""
        value = self.grad_input
        return convert_tuple(value)
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
@dataclass
class ModuleBackwardInputs:
    """Holds only the ``grad_input`` of a backward call."""
    grad_input: Optional[Tuple]

    @property
    def grad_input_tuple(self):
        """``grad_input`` normalised with ``convert_tuple``."""
        grads = self.grad_input
        return convert_tuple(grads)
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
@dataclass
class ModuleBackwardOutputs:
    """Holds only the ``grad_output`` of a backward call."""
    grad_output: Optional[Tuple]

    @property
    def grad_output_tuple(self):
        """``grad_output`` normalised with ``convert_tuple``."""
        grads = self.grad_output
        return convert_tuple(grads)
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
class TensorStatInfo:
    """Plain holder for summary statistics (max, min, mean, norm), presumably of one tensor."""

    def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
        # Each statistic is stored as given; unspecified ones stay None.
        self.max, self.min, self.mean, self.norm = max_val, min_val, mean_val, norm_val
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
class BaseDataProcessor:
|
|
71
|
-
_recursive_key_stack = []
|
|
72
|
-
special_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
|
|
73
|
-
bool, int, float, str, slice)
|
|
74
|
-
|
|
75
|
-
def __init__(self, config, data_writer):
|
|
76
|
-
self.data_writer = data_writer
|
|
77
|
-
self.config = config
|
|
78
|
-
self.api_info_struct = {}
|
|
79
|
-
self.stack_info_struct = {}
|
|
80
|
-
self.current_api_or_module_name = None
|
|
81
|
-
self.api_data_category = None
|
|
82
|
-
self.
|
|
83
|
-
self.
|
|
84
|
-
self.
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
" ".join(["
|
|
105
|
-
" ".join(["
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
stack_info_struct
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
np.
|
|
127
|
-
np.
|
|
128
|
-
np.
|
|
129
|
-
np.
|
|
130
|
-
np.
|
|
131
|
-
np.
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
def
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
api_info_struct[name]
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
1
|
+
import os
|
|
2
|
+
import inspect
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Tuple, Dict, Optional, Any
|
|
5
|
+
import numpy as np
|
|
6
|
+
from msprobe.core.common.log import logger
|
|
7
|
+
from msprobe.core.common.utils import convert_tuple
|
|
8
|
+
from msprobe.core.common.const import Const
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class ModuleForwardInputsOutputs:
    """Positional args, keyword args and output captured from one forward call.

    Attributes:
        args: positional arguments of the call; may be None.
        kwargs: keyword arguments of the call; may be None.
        output: whatever the call returned.
    """
    args: Optional[Tuple]
    kwargs: Optional[Dict]
    output: Any

    @property
    def args_tuple(self):
        """``args`` passed through ``convert_tuple``."""
        return convert_tuple(self.args)

    @property
    def output_tuple(self):
        """``output`` passed through ``convert_tuple``."""
        return convert_tuple(self.output)

    def concat_args_and_kwargs(self):
        """Return the positional args followed by the keyword-argument values as one tuple.

        Both fields are declared Optional, so a missing (None) ``args`` or ``kwargs`` is
        treated as empty instead of raising TypeError.
        """
        args = tuple(self.args) if self.args is not None else ()
        kwargs = self.kwargs if self.kwargs is not None else {}
        return args + tuple(kwargs.values())
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class ModuleBackwardInputsOutputs:
    """Holds both the ``grad_output`` and the ``grad_input`` of a backward call.

    The field order (grad_output, then grad_input) defines the positional constructor.
    """
    grad_output: Optional[Tuple]
    grad_input: Optional[Tuple]

    @property
    def grad_output_tuple(self):
        """``grad_output`` normalised with ``convert_tuple``."""
        value = self.grad_output
        return convert_tuple(value)

    @property
    def grad_input_tuple(self):
        """``grad_input`` normalised with ``convert_tuple``."""
        value = self.grad_input
        return convert_tuple(value)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class ModuleBackwardInputs:
    """Holds only the ``grad_input`` of a backward call."""
    grad_input: Optional[Tuple]

    @property
    def grad_input_tuple(self):
        """``grad_input`` normalised with ``convert_tuple``."""
        grads = self.grad_input
        return convert_tuple(grads)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
class ModuleBackwardOutputs:
    """Holds only the ``grad_output`` of a backward call."""
    grad_output: Optional[Tuple]

    @property
    def grad_output_tuple(self):
        """``grad_output`` normalised with ``convert_tuple``."""
        grads = self.grad_output
        return convert_tuple(grads)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class TensorStatInfo:
    """Plain holder for summary statistics (max, min, mean, norm), presumably of one tensor."""

    def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
        # Each statistic is stored as given; unspecified ones stay None.
        self.max, self.min, self.mean, self.norm = max_val, min_val, mean_val, norm_val
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class BaseDataProcessor:
    """Framework-agnostic base that turns API/module call data into dump records.

    Builds ``{name: {...}}`` dicts describing the inputs and outputs of forward and backward
    calls, filtered by ``config.data_mode``. ``analyze_element`` hands every leaf value to
    ``self.analyze_single_element``, which this class does not define; presumably it is
    supplied by the framework-specific subclasses.
    """

    # Keys/indices on the path currently being visited by recursive_apply_transform; the
    # stack is passed to the transform so it knows where a leaf sits.
    # NOTE: class attribute, hence shared by every instance and subclass.
    _recursive_key_stack = []
    # Leaf types handed straight to the transform instead of being traversed.
    # np.unicode_ (an alias of np.str_, removed in NumPy 2.0) and np.byte (an alias of
    # np.int8, already covered by np.integer) are deliberately not listed.
    special_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_,
                    bool, int, float, str, slice, type(Ellipsis))

    def __init__(self, config, data_writer):
        self.data_writer = data_writer
        self.config = config
        self.api_info_struct = {}
        self.stack_info_struct = {}
        self.current_api_or_module_name = None
        # Part of the call currently analysed (input / kwargs / output); used in file names.
        self.api_data_category = None
        self.current_iter = 0
        self._return_forward_new_output = False
        self._forward_new_output = None

    @property
    def data_path(self):
        """Directory where tensor data files are written (taken from the data writer)."""
        return self.data_writer.dump_tensor_data_dir

    @property
    def is_terminated(self):
        """Whether dumping should stop; the base processor never terminates."""
        return False

    @staticmethod
    def analyze_api_call_stack(name):
        """Return ``{name: [frame description, ...]}`` for the current Python call stack.

        The innermost 5 frames are skipped (presumably the dump/hook machinery between the
        user call and this method -- verify if that call chain changes). Frames without
        source code are ignored.
        """
        stack_str = []
        for (_, path, line, func, code, _) in inspect.stack()[5:]:
            if not code:
                continue
            stack_line = " ".join([
                "File", ", ".join([
                    path,
                    " ".join(["line", str(line)]),
                    " ".join(["in", func]),
                    " ".join(["\n", code[0].strip()])
                ])
            ])
            stack_str.append(stack_line)
        stack_info_struct = {name: stack_str}
        return stack_info_struct

    @staticmethod
    def transfer_type(data):
        """Convert int-like / float-like scalars to Python ``int`` / ``float``.

        The test is a substring match on ``str(type(data))``; other values are returned as-is.
        NOTE(review): any type whose repr contains "int" or "float" (e.g. a class named
        ``Point``) also matches -- confirm callers only pass numeric scalars.
        """
        dtype = str(type(data))
        if 'int' in dtype:
            return int(data)
        elif 'float' in dtype:
            return float(data)
        else:
            return data

    @staticmethod
    def _convert_numpy_to_builtin(arg):
        """Return ``(builtin_value, numpy_type_name)`` for a NumPy scalar, else ``(arg, '')``."""
        # The first matching abstract NumPy type wins. np.byte (== np.int8) is handled by
        # np.integer, and np.unicode_ (== np.str_, removed in NumPy 2.0) by np.str_.
        type_mapping = {
            np.integer: int,
            np.floating: float,
            np.bool_: bool,
            np.complexfloating: complex,
            np.str_: str,
        }
        for numpy_type, builtin_type in type_mapping.items():
            if isinstance(arg, numpy_type):
                return builtin_type(arg), type(arg).__name__
        return arg, ''

    @staticmethod
    def _analyze_builtin(arg):
        """Describe a builtin leaf value as ``{"type": ..., "value": ...}`` for JSON output."""
        single_arg = {}
        if isinstance(arg, slice):
            # The slice parameter may be of the tensor, numpy or other types.
            # It needs to be converted to the Python value type before JSON serialization
            single_arg.update({"type": "slice"})
            values = []
            for value in [arg.start, arg.stop, arg.step]:
                if value is not None:
                    try:
                        value = int(value)
                    except (TypeError, ValueError, OverflowError):
                        # int() raises TypeError for objects without __int__/__index__ and
                        # OverflowError for infinities; record None rather than abort the dump.
                        logger.warning(f"The data type {type(value)} cannot be converted to int type.")
                        value = None
                values.append(value)
            single_arg.update({"value": values})
        else:
            single_arg.update({"type": type(arg).__name__})
            # When arg is Ellipsis(...) type, it needs to be converted to str("...") type
            single_arg.update({"value": arg if arg is not Ellipsis else "..."})
        return single_arg

    @staticmethod
    def _analyze_numpy(value, numpy_type):
        """Describe an already-converted NumPy scalar as ``{"type": ..., "value": ...}``."""
        return {"type": numpy_type, "value": value}

    @classmethod
    def get_special_types(cls):
        """Return the tuple of leaf types recognised by ``recursive_apply_transform``."""
        return cls.special_type

    @classmethod
    def recursive_apply_transform(cls, args, transform):
        """Apply ``transform`` to every leaf of a nested list/tuple/dict structure.

        Leaves are instances of ``cls.get_special_types()``; each is passed to
        ``transform(leaf, key_stack)``, where ``key_stack`` holds the string keys/indices on
        the path to the leaf. Lists/tuples (namedtuples included) are rebuilt with their own
        type and dicts as plain dicts. Unsupported non-None values are logged and become None;
        None stays None.
        """
        if isinstance(args, cls.get_special_types()):
            return transform(args, cls._recursive_key_stack)
        if isinstance(args, (list, tuple)):
            result_list = []
            for i, arg in enumerate(args):
                cls._recursive_key_stack.append(str(i))
                try:
                    result_list.append(cls.recursive_apply_transform(arg, transform))
                finally:
                    # Pop even if the transform raises, so the shared stack stays balanced.
                    cls._recursive_key_stack.pop()
            args_type = type(args)
            if isinstance(args, tuple) and hasattr(args_type, "_fields") and hasattr(args_type, "_make"):
                # namedtuple constructors take the fields positionally, not as one iterable.
                return args_type._make(result_list)
            return args_type(result_list)
        if isinstance(args, dict):
            result_dict = {}
            for k, arg in args.items():
                cls._recursive_key_stack.append(str(k))
                try:
                    result_dict[k] = cls.recursive_apply_transform(arg, transform)
                finally:
                    cls._recursive_key_stack.pop()
            return result_dict
        if args is not None:
            logger.warning(f"Data type {type(args)} is not supported.")
        return None

    def if_return_forward_new_output(self):
        """Whether a replacement forward output is pending."""
        return self._return_forward_new_output

    def get_forward_new_output(self):
        """Return the pending replacement forward output and clear the pending flag."""
        self._return_forward_new_output = False
        return self._forward_new_output

    def update_iter(self, current_iter):
        self.current_iter = current_iter

    def update_api_or_module_name(self, api_or_module_name):
        if self.current_api_or_module_name != api_or_module_name:
            self.current_api_or_module_name = api_or_module_name

    def is_dump_for_data_mode(self, forward_backward, input_output):
        """
        Compare the parameters with data_mode to determine whether to dump.

        Args:
            forward_backward(str): The forward or backward mode to check.
            input_output(str): The input or output mode to check.

        Return:
            bool: True if the parameters are in data_mode or data_mode is all, False otherwise.
        """
        return (Const.ALL in self.config.data_mode or
                forward_backward in self.config.data_mode or
                input_output in self.config.data_mode)

    def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        """Hook called before a forward call; a no-op in the base class."""
        pass

    def analyze_element(self, element):
        """Analyse every leaf of ``element`` with ``self.analyze_single_element``."""
        return self.recursive_apply_transform(element, self.analyze_single_element)

    def _analyze_forward_input(self, name, module_input_output, api_info_struct):
        """Store the analysed positional and keyword inputs under ``api_info_struct[name]``."""
        api_info_struct[name] = {}
        self.api_data_category = Const.INPUT
        args_info_list = self.analyze_element(module_input_output.args_tuple)
        api_info_struct[name][Const.INPUT_ARGS] = args_info_list
        self.api_data_category = Const.KWARGS
        kwargs_info_list = self.analyze_element(module_input_output.kwargs)
        api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list

    def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        """Describe the inputs and/or output of a forward call, as selected by data_mode."""
        api_info_struct = {}
        # check whether data_mode contains forward or input
        if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
            self._analyze_forward_input(name, module_input_output, api_info_struct)

        # check whether data_mode contains forward or output
        if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
            api_info_struct[name] = api_info_struct.get(name, {})
            self.api_data_category = Const.OUTPUT
            output_info_list = self.analyze_element(module_input_output.output_tuple)
            api_info_struct[name][Const.OUTPUT] = output_info_list
        return api_info_struct

    def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
        """Describe the inputs of a forward call before it runs, if data_mode selects inputs."""
        api_info_struct = {}
        if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
            self._analyze_forward_input(name, module_input_output, api_info_struct)
        return api_info_struct

    def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
        """Record the call's args and kwarg values together as its output.

        The name suggests in-place APIs, whose arguments hold the result after the call.
        """
        api_info_struct = {}
        if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
            # Only build the concatenated arguments when they are actually dumped.
            concat_args = module_input_output.concat_args_and_kwargs()
            api_info_struct[name] = {}
            self.api_data_category = Const.OUTPUT
            output_info_list = self.analyze_element(concat_args)
            api_info_struct[name][Const.OUTPUT] = output_info_list
        return api_info_struct

    def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
        """Describe grad_input and/or grad_output of a backward call, as selected by data_mode."""
        api_info_struct = {}
        if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
            api_info_struct[name] = {}
            self.api_data_category = Const.INPUT
            input_info_list = self.analyze_element(module_input_output.grad_input_tuple)
            api_info_struct[name][Const.INPUT] = input_info_list

        if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
            api_info_struct[name] = api_info_struct.get(name, {})
            self.api_data_category = Const.OUTPUT
            output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
            api_info_struct[name][Const.OUTPUT] = output_info_list

        return api_info_struct

    def analyze_backward_input(self, name, module,
                               module_input_output: ModuleBackwardInputs):
        """Describe only grad_input of a backward call, if data_mode selects backward inputs."""
        api_info_struct = {}
        if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
            api_info_struct[name] = {}
            self.api_data_category = Const.INPUT

            input_info_list = self.analyze_element(module_input_output.grad_input_tuple)
            api_info_struct[name][Const.INPUT] = input_info_list
        return api_info_struct

    def analyze_backward_output(self, name, module,
                                module_input_output: ModuleBackwardOutputs):
        """Describe only grad_output of a backward call, if data_mode selects backward outputs."""
        api_info_struct = {}
        if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
            api_info_struct[name] = {}
            self.api_data_category = Const.OUTPUT

            output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
            api_info_struct[name][Const.OUTPUT] = output_info_list
        return api_info_struct

    def get_save_file_path(self, suffix):
        """Return ``(file_name, full_path)`` for a data file of the current API/module.

        The name is ``<api_or_module>.<category>.<suffix>`` joined with ``Const.SEP`` plus the
        PyTorch suffix when the framework is PyTorch, otherwise the NumPy suffix.
        """
        file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
        dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
                          suffix + file_format)
        file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
        return dump_data_name, file_path
|