mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
- mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +101 -237
- msprobe/{config/config.json → config.json} +49 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +59 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +341 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +478 -283
- msprobe/core/common/log.py +76 -69
- msprobe/core/common/utils.py +385 -616
- msprobe/core/common_config.py +85 -71
- msprobe/core/compare/acc_compare.py +299 -298
- msprobe/core/compare/check.py +95 -95
- msprobe/core/compare/compare_cli.py +49 -49
- msprobe/core/compare/highlight.py +223 -222
- msprobe/core/compare/multiprocessing_compute.py +149 -149
- msprobe/core/compare/npy_compare.py +295 -295
- msprobe/core/compare/utils.py +430 -429
- msprobe/core/data_dump/data_collector.py +154 -144
- msprobe/core/data_dump/data_processor/base.py +314 -293
- msprobe/core/data_dump/data_processor/factory.py +59 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
- msprobe/core/data_dump/json_writer.py +96 -116
- msprobe/core/data_dump/scope.py +178 -178
- msprobe/core/grad_probe/constant.py +70 -70
- msprobe/core/grad_probe/grad_compare.py +171 -175
- msprobe/core/grad_probe/utils.py +64 -52
- msprobe/docs/01.installation.md +89 -0
- msprobe/docs/02.config_introduction.md +165 -0
- msprobe/docs/03.config_examples.md +247 -0
- msprobe/docs/04.acl_config_examples.md +76 -0
- msprobe/docs/05.data_dump_PyTorch.md +198 -0
- msprobe/docs/06.data_dump_MindSpore.md +243 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
- msprobe/docs/FAQ_PyTorch.md +177 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +34 -34
- msprobe/mindspore/common/const.py +106 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +81 -57
- msprobe/mindspore/compare/distributed_compare.py +75 -75
- msprobe/mindspore/compare/ms_compare.py +219 -117
- msprobe/mindspore/compare/ms_graph_compare.py +348 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +66 -74
- msprobe/mindspore/debugger/precision_debugger.py +126 -107
- msprobe/mindspore/dump/dump_tool_factory.py +35 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
- msprobe/mindspore/dump/jit_dump.py +72 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
- msprobe/mindspore/free_benchmark/common/config.py +12 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
- msprobe/mindspore/free_benchmark/common/utils.py +71 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
- msprobe/mindspore/grad_probe/global_context.py +90 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +378 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +3 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
- msprobe/pytorch/bench_functions/__init__.py +15 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
- msprobe/pytorch/bench_functions/linear.py +12 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
- msprobe/pytorch/bench_functions/rms_norm.py +15 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
- msprobe/pytorch/bench_functions/swiglu.py +55 -55
- msprobe/pytorch/common/__init__.py +2 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +20 -31
- msprobe/pytorch/common/parse_json.py +39 -39
- msprobe/pytorch/common/utils.py +305 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +34 -33
- msprobe/pytorch/compare/pt_compare.py +50 -40
- msprobe/pytorch/debugger/debugger_config.py +95 -95
- msprobe/pytorch/debugger/precision_debugger.py +125 -125
- msprobe/pytorch/free_benchmark/__init__.py +8 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +37 -37
- msprobe/pytorch/free_benchmark/common/params.py +129 -129
- msprobe/pytorch/free_benchmark/common/utils.py +102 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
- msprobe/pytorch/free_benchmark/main.py +105 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
- msprobe/pytorch/function_factory.py +76 -75
- msprobe/pytorch/functional/dump_module.py +39 -39
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/api_registry.py +161 -161
- msprobe/pytorch/hook_module/hook_module.py +120 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +30 -29
- msprobe/pytorch/hook_module/wrap_aten.py +110 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
- msprobe/pytorch/hook_module/wrap_functional.py +105 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
- msprobe/pytorch/hook_module/wrap_torch.py +86 -86
- msprobe/pytorch/hook_module/wrap_vf.py +62 -62
- msprobe/pytorch/module_processer.py +138 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +236 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
- msprobe/pytorch/online_dispatch/utils.py +130 -146
- msprobe/pytorch/parse.py +4 -4
- msprobe/pytorch/parse_tool/cli.py +32 -32
- msprobe/pytorch/parse_tool/lib/compare.py +260 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
- msprobe/pytorch/parse_tool/lib/utils.py +316 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +188 -187
- msprobe/pytorch/service.py +246 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
|
@@ -1,74 +1,70 @@
|
|
|
1
|
-
import os
|
|
2
|
-
import re
|
|
3
|
-
|
|
4
|
-
from msprobe.core.common.const import FileCheckConst
|
|
5
|
-
from msprobe.core.common.
|
|
6
|
-
from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
|
|
7
|
-
from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
|
|
8
|
-
from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
|
|
9
|
-
from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
|
|
10
|
-
from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
|
|
11
|
-
|
|
12
|
-
# Module-level list of API names. The name suggests these APIs are compared
# under an HF32 precision standard -- NOTE(review): confirm against the callers
# that read this list.
hf_32_standard_api = ["conv1d", "conv2d"]
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class Backward_Message:
    """String constants explaining why backward handling is skipped or unsupported."""

    # Message for the "more than one backward pass" case.
    MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
    # Message for calls made with an ``out=`` argument, which autograd cannot differentiate.
    UNSUPPORT_BACKWARD_MESSAGE = (
        "function with out=... arguments don't support automatic differentiation, "
        "skip backward."
    )
    # Message for when the backward call produced no result.
    NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class UtDataInfo:
    """Plain record holding the data of one run_ut comparison.

    Every constructor argument is stored unchanged as a same-named attribute.
    Going by the names, ``bench_*`` / ``device_*`` are the benchmark and
    device-side results -- NOTE(review): confirm against the producer.
    """

    def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list,
                 backward_message, rank=0):
        # Attributes are assigned left to right, in the same order as the parameters.
        self.bench_grad, self.device_grad = bench_grad, device_grad
        self.device_output, self.bench_output = device_output, bench_output
        self.grad_in, self.in_fwd_data_list = grad_in, in_fwd_data_list
        self.backward_message, self.rank = backward_message, rank
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
def get_validated_result_csv_path(result_csv_path, mode):
    """Validate an existing run_ut CSV file and return the checked path.

    Args:
        result_csv_path: path of the CSV file to check (read/write, ``.csv``).
        mode: ``'result'`` or ``'detail'``. In ``'result'`` mode the file name
            must also match ``accuracy_checking_result_<14 digits>.csv``.

    Returns:
        The path returned by ``FileChecker.common_check()``.

    Raises:
        ValueError: if ``mode`` is unsupported, or in ``'result'`` mode the
            file name does not match the expected pattern.
    """
    if mode not in ('result', 'detail'):
        raise ValueError("The csv mode must be result or detail")
    checker = FileChecker(result_csv_path, FileCheckConst.FILE,
                          ability=FileCheckConst.READ_WRITE_ABLE,
                          file_type=FileCheckConst.CSV_SUFFIX)
    checked_path = checker.common_check()
    if mode != 'result':
        # 'detail' mode: no naming constraint.
        return checked_path
    file_name = os.path.basename(checked_path)
    if re.match(r"^accuracy_checking_result_\d{14}\.csv$", file_name) is None:
        raise ValueError("When continue run ut, please do not modify the result csv name.")
    return checked_path
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
def get_validated_details_csv_path(validated_result_csv_path):
    """Derive the details CSV path that sits next to a result CSV, and validate it.

    Every ``'result'`` substring in the file name (not the directory) is replaced
    by ``'details'``. The resulting path is checked as a read/write ``.csv`` file.

    Returns:
        The path returned by ``FileChecker.common_check()``.
    """
    directory, file_name = os.path.split(validated_result_csv_path)
    details_path = os.path.join(directory, file_name.replace('result', 'details'))
    checker = FileChecker(details_path, FileCheckConst.FILE,
                          ability=FileCheckConst.READ_WRITE_ABLE,
                          file_type=FileCheckConst.CSV_SUFFIX)
    return checker.common_check()
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
def exec_api(api_type, api_name, args, kwargs):
|
|
59
|
-
if api_type == "Functional":
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
if api_type == "
|
|
66
|
-
torch_api =
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
if api_type == "NPU":
|
|
72
|
-
torch_api = NpuOPTemplate(api_name, None, False)
|
|
73
|
-
out = torch_api.forward(*args, **kwargs)
|
|
74
|
-
return out
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
|
|
4
|
+
from msprobe.core.common.const import FileCheckConst
|
|
5
|
+
from msprobe.core.common.file_utils import FileChecker
|
|
6
|
+
from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
|
|
7
|
+
from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
|
|
8
|
+
from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
|
|
9
|
+
from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
|
|
10
|
+
from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
|
|
11
|
+
|
|
12
|
+
# Module-level list of API names. The name suggests these APIs are compared
# under an HF32 precision standard -- NOTE(review): confirm against the callers
# that read this list.
hf_32_standard_api = ["conv1d", "conv2d"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Backward_Message:
    """String constants explaining why backward handling is skipped or unsupported."""

    # Message for the "more than one backward pass" case.
    MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
    # Message for calls made with an ``out=`` argument, which autograd cannot differentiate.
    UNSUPPORT_BACKWARD_MESSAGE = (
        "function with out=... arguments don't support automatic differentiation, "
        "skip backward."
    )
    # Message for when the backward call produced no result.
    NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class UtDataInfo:
    """Plain record holding the data of one run_ut comparison.

    Every constructor argument is stored unchanged as a same-named attribute.
    Going by the names, ``bench_*`` / ``device_*`` are the benchmark and
    device-side results -- NOTE(review): confirm against the producer.
    """

    def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list,
                 backward_message, rank=0):
        # Attributes are assigned left to right, in the same order as the parameters.
        self.bench_grad, self.device_grad = bench_grad, device_grad
        self.device_output, self.bench_output = device_output, bench_output
        self.grad_in, self.in_fwd_data_list = grad_in, in_fwd_data_list
        self.backward_message, self.rank = backward_message, rank
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_validated_result_csv_path(result_csv_path, mode):
    """Validate an existing run_ut CSV file and return the checked path.

    Args:
        result_csv_path: path of the CSV file to check (read/write, ``.csv``).
        mode: ``'result'`` or ``'detail'``. In ``'result'`` mode the file name
            must also match ``accuracy_checking_result_<14 digits>.csv``.

    Returns:
        The path returned by ``FileChecker.common_check()``.

    Raises:
        ValueError: if ``mode`` is unsupported, or in ``'result'`` mode the
            file name does not match the expected pattern.
    """
    if mode not in ('result', 'detail'):
        raise ValueError("The csv mode must be result or detail")
    checker = FileChecker(result_csv_path, FileCheckConst.FILE,
                          ability=FileCheckConst.READ_WRITE_ABLE,
                          file_type=FileCheckConst.CSV_SUFFIX)
    checked_path = checker.common_check()
    if mode != 'result':
        # 'detail' mode: no naming constraint.
        return checked_path
    file_name = os.path.basename(checked_path)
    if re.match(r"^accuracy_checking_result_\d{14}\.csv$", file_name) is None:
        raise ValueError("When continue run ut, please do not modify the result csv name.")
    return checked_path
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_validated_details_csv_path(validated_result_csv_path):
    """Derive the details CSV path that sits next to a result CSV, and validate it.

    Every ``'result'`` substring in the file name (not the directory) is replaced
    by ``'details'``. The resulting path is checked as a read/write ``.csv`` file.

    Returns:
        The path returned by ``FileChecker.common_check()``.
    """
    directory, file_name = os.path.split(validated_result_csv_path)
    details_path = os.path.join(directory, file_name.replace('result', 'details'))
    checker = FileChecker(details_path, FileCheckConst.FILE,
                          ability=FileCheckConst.READ_WRITE_ABLE,
                          file_type=FileCheckConst.CSV_SUFFIX)
    return checker.common_check()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def exec_api(api_type, api_name, device, args, kwargs):
    """Build the hooked-API template for ``api_type`` and run its forward.

    Args:
        api_type: one of ``"Functional"``, ``"Tensor"``, ``"Torch"``, ``"Aten"``, ``"NPU"``.
        api_name: name of the API to execute.
        device: forwarded to ``NpuOPTemplate`` only (used for ``"NPU"`` APIs).
        args: positional arguments passed to ``forward``.
        kwargs: keyword arguments passed to ``forward``.

    Returns:
        Whatever the template's ``forward(*args, **kwargs)`` returns.

    Raises:
        ValueError: if ``api_type`` is not a supported kind. Previously an unknown
            type crashed with an ``UnboundLocalError`` on ``torch_api``.
    """
    # The second/third constructor arguments are passed through exactly as before;
    # their meaning is defined by the template classes.
    if api_type == "Functional":
        torch_api = FunctionalOPTemplate(api_name, str, False)
    elif api_type == "Tensor":
        torch_api = TensorOPTemplate(api_name, str, False)
    elif api_type == "Torch":
        torch_api = TorchOPTemplate(api_name, str, False)
    elif api_type == "Aten":
        torch_api = AtenOPTemplate(api_name, None, False)
    elif api_type == "NPU":
        torch_api = NpuOPTemplate(api_name, None, False, device)
    else:
        raise ValueError(f"Unsupported api_type: {api_type}")
    out = torch_api.forward(*args, **kwargs)
    return out
|
|
@@ -1,202 +1,197 @@
|
|
|
1
|
-
import
|
|
2
|
-
import os.path
|
|
3
|
-
import time
|
|
4
|
-
import re
|
|
5
|
-
from
|
|
6
|
-
from
|
|
7
|
-
from
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
import
|
|
12
|
-
|
|
13
|
-
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.
|
|
14
|
-
from msprobe.pytorch.
|
|
15
|
-
from msprobe.
|
|
16
|
-
from msprobe.pytorch.common.utils import save_pt
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
self.
|
|
40
|
-
self.
|
|
41
|
-
self.
|
|
42
|
-
self.
|
|
43
|
-
self.
|
|
44
|
-
self.
|
|
45
|
-
|
|
46
|
-
self.
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
self.socket_manager
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
self.session_config.
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
self.socket_manager
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
self.
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
if
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
def
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
buffer
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
if timeout_ms > 0:
|
|
105
|
-
|
|
106
|
-
if
|
|
107
|
-
buffer =
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
if buffer == b"
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
self.
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
return
|
|
176
|
-
elif
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
if target_device == torch.device('cpu') or target_device == "cpu":
|
|
200
|
-
return ApiData(buffer.name, tuple(new_args), new_kwargs, new_results, buffer.step, buffer.rank)
|
|
201
|
-
else:
|
|
202
|
-
return ApiData(buffer.name, tuple(new_args), new_kwargs, buffer.result, buffer.step, buffer.rank)
|
|
1
|
+
import glob
|
|
2
|
+
import os.path
|
|
3
|
+
import time
|
|
4
|
+
import re
|
|
5
|
+
from multiprocessing import Queue
|
|
6
|
+
from typing import Optional, Union, Dict, Any
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
|
|
12
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
|
|
13
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
|
|
14
|
+
from msprobe.pytorch.common.utils import logger
|
|
15
|
+
from msprobe.core.common.file_utils import remove_path
|
|
16
|
+
from msprobe.pytorch.common.utils import save_api_data, load_api_data, save_pt, load_pt
|
|
17
|
+
|
|
18
|
+
# Payload type accepted by ATTL.send / ATTL.upload and returned by ATTL.recv:
# an ApiData record, a dict, or a plain string (recv returns control strings
# such as "STOP_" / "KILL_"). The trailing comment appears to be a leftover note
# on tensor payload shapes -- NOTE(review): confirm whether it is still relevant.
BufferType = Union[ApiData, Dict[str, Any], str]  # Union[Tensor, Tuple[Optional[Tensor]]]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class ATTLConfig:
    """Connection and storage settings for an ATTL session.

    Attributes:
        is_benchmark_device: True on the benchmark side, which listens with a TCP server;
            False on the device side, which connects as a client.
        connect_ip: IPv4 address to connect to (validated as a dotted quad).
        connect_port: TCP port, 1..65535.
        nfs_path: if set, data is exchanged through files in this directory instead
            of sockets. ``None`` means socket mode.
        tls_path: optional TLS material location handed to the socket endpoints.
        check_sum: flag handed to the socket endpoints.
        queue_size: capacity setting for the receive queue
            (NOTE(review): confirm the ATTL version in use reads it).
    """
    is_benchmark_device: bool
    connect_ip: str
    connect_port: int
    # storage_config
    # Optional[...]: these default to None, so a bare ``str`` annotation was inaccurate.
    nfs_path: Optional[str] = None
    tls_path: Optional[str] = None
    check_sum: bool = True
    queue_size: int = 50
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ATTL:
    """Tensor transport for online accuracy checking.

    The device side sends payloads through a ``TCPClient`` (``send``). The benchmark
    side receives them from a ``TCPServer``, which is given ``data_queue``, and
    drains them with ``recv``. When ``session_config.nfs_path`` is set, no socket
    is created and data is exchanged as files in that directory (``upload`` /
    ``download``).
    """

    def __init__(self, session_id: str, session_config: ATTLConfig, need_dump=True) -> None:
        """Create the transport and, unless NFS mode is configured, start a socket endpoint.

        Args:
            session_id: identifier stored on the instance.
            session_config: connection/storage settings, validated by ``check_attl_config``.
            need_dump: on a non-benchmark device in socket mode, a ``TCPClient`` is
                started only when this is true.

        Raises:
            Exception: propagated from ``check_attl_config`` on invalid configuration.
        """
        self.session_id = session_id
        self.session_config = session_config
        self.logger = logger
        self.socket_manager = None
        # Honor the configured capacity rather than a hard-coded 50
        # (ATTLConfig.queue_size defaults to 50, so the default behavior is unchanged).
        self.data_queue = Queue(maxsize=self.session_config.queue_size)
        self.dequeue_list = []
        self.message_end = False
        self.kill_progress = False
        self.check_attl_config()
        if self.session_config.nfs_path:
            # NFS mode: files are used instead of sockets.
            self.nfs_path = self.session_config.nfs_path
        elif self.session_config.is_benchmark_device:
            # Benchmark side: the server is handed data_queue, which recv() drains.
            self.socket_manager = TCPServer(self.session_config.connect_port,
                                            self.data_queue,
                                            self.session_config.check_sum,
                                            self.session_config.tls_path)
            self.socket_manager.start()
        elif need_dump:
            self.socket_manager = TCPClient(self.session_config.connect_ip,
                                            self.session_config.connect_port,
                                            self.session_config.check_sum,
                                            self.session_config.tls_path)
            self.socket_manager.start()

    def check_attl_config(self):
        """Validate the session configuration.

        In NFS mode only the existence of ``nfs_path`` is checked. Otherwise
        ``connect_ip`` must be a dotted-quad IPv4 address and ``connect_port``
        must be in 1..65535.

        Raises:
            Exception: if the NFS path does not exist, or the IP/port is invalid.
        """
        if self.session_config.nfs_path:
            if os.path.exists(self.session_config.nfs_path):
                return
            else:
                raise Exception(f"nfs path {self.session_config.nfs_path} doesn't exists.")
        # Raw string: "\d" / "\." in a plain literal are invalid escape sequences
        # (SyntaxWarning on Python 3.12+). The compiled pattern is unchanged.
        ipv4_pattern = r"([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
        if not re.match(ipv4_pattern, self.session_config.connect_ip):
            raise Exception(f"host {self.session_config.connect_ip} is invalid.")
        if not (0 < self.session_config.connect_port <= 65535):
            raise Exception(f"port {self.session_config.connect_port} is invalid.")

    def stop_serve(self):
        """Stop the socket endpoint if it is a ``TCPServer``; otherwise do nothing."""
        if isinstance(self.socket_manager, TCPServer):
            self.socket_manager.stop()

    def send(self, buffer: BufferType) -> None:
        """Serialize ``buffer`` and queue it on the socket client for sending.

        Used on the device (NPU) side, which is the client. An ``ApiData`` buffer is
        first copied to CPU, and any ``device`` kwarg is removed. If serialization
        fails, the error is logged and the item is skipped.

        NOTE(review): ``buffer.kwargs`` is read unconditionally, so non-ApiData
        buffers fail here.
        """
        if isinstance(buffer, ApiData):
            buffer = move2target_device(buffer, torch.device('cpu'))

        if 'device' in buffer.kwargs:
            # Strip the caller's explicit device argument before serialization.
            buffer.kwargs.pop('device')
        rank = buffer.rank if hasattr(buffer, "rank") and buffer.rank is not None else 0
        step = buffer.step if hasattr(buffer, "step") else 0
        try:
            io_buff = save_api_data(buffer)
        except Exception as e:
            self.logger.info(f"{buffer.name} can not be saved, skip: {e}")
            return
        data = io_buff.getvalue()
        self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)

    def recv(self, timeout_ms=0) -> Optional[BufferType]:
        """Take the next payload from ``data_queue``.

        With ``timeout_ms > 0``: sleep once for that long, then return the next item,
        or ``None`` if nothing has arrived. With ``timeout_ms == 0``: poll every
        0.1 s until an item arrives, or until a previous ``KILL_`` has been seen and
        the queue is empty.

        Control payloads:
            ``b"STOP_"``: returns ``"STOP_"``.
            ``b"KILL_"``: sets ``message_end`` and returns ``"STOP_"``.
            Queue drained after ``KILL_``: sets ``kill_progress`` and returns ``"KILL_"``.

        Other payloads are decoded with ``load_api_data``. If decoding fails, the
        error is logged and ``None`` is returned.
        """
        buffer = None
        while buffer is None:
            if timeout_ms > 0:
                time.sleep(timeout_ms / 1000.0)
            if buffer is None and not self.data_queue.empty():
                buffer = self.data_queue.get()
                break
            if buffer is None and timeout_ms > 0:  # timeout is the only case we give up and return None
                break
            if self.message_end and self.data_queue.empty():
                buffer = b"KILL_CONFIRM"
                self.kill_progress = True
                break
            time.sleep(0.1)  # back off before polling again
        if buffer is None:
            # this is a result of a timeout
            self.logger.info("RECEIVE API DATA TIMED OUT")
        else:
            if buffer == b"STOP_":
                return "STOP_"
            if buffer == b"KILL_":
                self.message_end = True
                return "STOP_"
            if buffer == b"KILL_CONFIRM":
                self.kill_progress = True
                return "KILL_"
            try:
                buffer = load_api_data(buffer)
            except Exception as e:
                self.logger.warning("there is something error. please check it. %s", e)
            if isinstance(buffer, bytes):
                # Decoding failed; the raw bytes are not returned.
                return None
            if isinstance(buffer, str):
                return buffer

        return buffer

    def upload(self, buffer: BufferType):
        """Save ``buffer`` as a file under ``session_config.nfs_path``.

        An ``ApiData`` buffer is copied to CPU and written as ``<name>.pt``. Any other
        buffer is used as a string filename prefix with a ``_<unix seconds>``
        suffix. Save errors are logged, not raised.
        """
        if isinstance(buffer, ApiData):
            buffer = move2target_device(buffer, torch.device('cpu'))
            file_path = os.path.join(self.session_config.nfs_path, buffer.name + ".pt")
        else:
            file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")

        try:
            save_pt(buffer, file_path)
        except Exception as e:
            self.logger.warning("there is something error in save_pt. please check it. %s", e)

    def download(self):
        """Load and delete one file from ``nfs_path``.

        Files matching ``start*`` are tried first, then ``*.pt``, then ``end*``.
        Returns the loaded object, or ``None`` if no file matched or loading failed.
        The chosen file is removed even when loading fails.
        """
        buffer = None
        cur_file = None
        for file_type in ("start*", "*.pt", "end*"):
            pattern = os.path.join(self.nfs_path, file_type)
            files = glob.glob(pattern)
            if len(files) > 0:
                cur_file = files[0]
                break

        if cur_file is not None:
            try:
                buffer = load_pt(cur_file)
            except Exception as e:
                self.logger.warning("there is something error. please check it. %s", e)
            remove_path(cur_file)
        return buffer
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def move2device_exec(obj, device):
    """Recursively copy ``obj`` onto ``device``.

    Containers are rebuilt:
        - Lists become plain lists; other tuple/list types become plain tuples.
        - Dicts are rebuilt with their values moved.
    Tensors are detached and moved only if their device type differs from
    ``device``. ``torch.return_types`` results become tuples. ``torch.device``
    values are replaced by ``torch.device(device)``. Anything else is returned
    unchanged.
    """
    if isinstance(obj, (tuple, list)):
        moved = [move2device_exec(item, device) for item in obj]
        if isinstance(obj, list):
            return moved
        return tuple(moved)
    if isinstance(obj, dict):
        return dict((k, move2device_exec(v, device)) for k, v in obj.items())
    if isinstance(obj, torch.Tensor):
        detached = obj.detach()
        if detached.device.type != device:
            return detached.to(device)
        return detached
    if "return_types" in str(type(obj)):
        return move2device_exec(tuple(obj), device)
    if isinstance(obj, torch._C.device):
        return torch.device(device)
    return obj
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def move2target_device(buffer: ApiData, target_device):
    """Return a new ``ApiData`` whose args/kwargs are moved to ``target_device``.

    The result field is moved only when the target is CPU. For any other target
    the original ``buffer.result`` is kept as-is, exactly as before. The result is
    no longer copied onto the target device and then thrown away, which wasted a
    full device transfer and device memory.

    Args:
        buffer: the API record to copy.
        target_device: a ``torch.device`` or device string.

    Returns:
        A new ``ApiData`` with the same name, step and rank.
    """
    # handle args
    new_args = move2device_exec(buffer.args, target_device)

    # handle kwargs
    new_kwargs = move2device_exec(buffer.kwargs, target_device)

    # handle result: only the CPU path returns a moved copy, so only move it there.
    if target_device == torch.device('cpu') or target_device == "cpu":
        new_results = move2device_exec(buffer.result, target_device)
    else:
        new_results = buffer.result

    return ApiData(buffer.name, tuple(new_args), new_kwargs, new_results, buffer.step, buffer.rank)
|