mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
|
@@ -0,0 +1,454 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
# Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
import argparse
|
|
18
|
+
import json
|
|
19
|
+
import os
|
|
20
|
+
import re
|
|
21
|
+
import math
|
|
22
|
+
import numpy as np
|
|
23
|
+
import torch
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import binary_standard_api, absolute_standard_api, ulp_standard_api, thousandth_standard_api
|
|
27
|
+
from msprobe.core.common.file_utils import FileOpen, load_json, save_json
|
|
28
|
+
from msprobe.core.common.utils import check_file_or_directory_path, check_op_str_pattern_valid, is_int
|
|
29
|
+
from msprobe.core.common.const import Const, MonitorConst, MsgConst
|
|
30
|
+
from msprobe.core.common.log import logger
|
|
31
|
+
from msprobe.core.common.file_utils import make_dir
|
|
32
|
+
from msprobe.core.common.utils import recursion_depth_decorator
|
|
33
|
+
|
|
34
|
+
# "type" values that mark a dumped record as a tensor.
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]

# dtype names grouped by kind (aliases such as torch.half / torch.long included).
TORCH_BOOL_TYPE = ["torch.bool"]
TORCH_INT_TYPE = [
    "torch.uint8", "torch.int8", "torch.int16", "torch.short",
    "torch.int32", "torch.int", "torch.int64", "torch.long",
]
TORCH_FLOAT_TYPE = [
    "torch.float16", "torch.half", "torch.bfloat16",
    "torch.float32", "torch.float", "torch.float64", "torch.double",
]
TORCH_COMPLEX_TYPE = [
    "torch.complex32", "torch.chalf", "torch.complex64",
    "torch.cfloat", "torch.complex128", "torch.cdouble",
]

# Accepted leading segments of a dumped API name.
OPERATOR_TYPE = ("Functional", "Tensor", "Torch")

# Maximum number of records (forward + backward) allowed in an extracted API json.
API_INFO = 2
# Segment counts of a dumped API name that can be split into type / name / ordinal.
FOUR_SEGMENT = 4
FIVE_SEGMENT = 5
# Record key whose value is joined onto the dump directory in real-data mode.
DATA_NAME = "data_name"
# Upper bound on len(api_name) accepted by the config check.
API_MAX_LENGTH = 30
# Accepted values of the "propagation" and "data_mode" settings.
PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
DATAMODE_LIST = ["random_data", "real_data"]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class APIInfo:
    """Dump record of one API call, optionally linked to its backward record."""

    def __init__(self, api_full_name, api_info_dict, backward_info=None):
        self.api_full_name = api_full_name
        self.api_info_dict = api_info_dict
        self.backward_info = backward_info

    @property
    def api_type(self):
        """First ``Const.SEP``-separated segment of the full API name."""
        leading, *_ = self.api_full_name.split(Const.SEP, -1)
        return leading

    @classmethod
    def from_json(cls, json_content, propagation):
        """Build an APIInfo from an extracted-API mapping.

        The first item is used as the forward record. When ``propagation`` is
        ``Const.BACKWARD``, the second item is attached as ``backward_info``.

        Raises:
            ValueError: if the forward API type is not in ``OPERATOR_TYPE``.
        """
        records = list(json_content.items())
        fwd_name, fwd_dict = records[0]
        result = cls(api_full_name=fwd_name, api_info_dict=fwd_dict)

        if propagation == Const.BACKWARD:
            bwd_name, bwd_dict = records[1]
            result.backward_info = cls(api_full_name=bwd_name, api_info_dict=bwd_dict)

        if not result.is_supported_type():
            raise ValueError(f"type {result.api_type} of API is not supported!")
        return result

    def is_supported_type(self):
        """Return True when ``api_type`` is one of ``OPERATOR_TYPE``."""
        return self.api_type in OPERATOR_TYPE
|
|
80
|
+
|
|
81
|
+
class CommonConfig:
    """User settings for the operator-script generator, validated on construction.

    Each setting is read with ``json_config.get(...)``. ``_check_config`` runs
    at the end of ``__init__``.
    """

    def __init__(self, json_config):
        self.dump_json_path = json_config.get('dump_json_path')
        self.api_name = json_config.get('api_name')
        self.extract_api_path = json_config.get('extract_api_path')
        self.propagation = json_config.get('propagation')
        self.data_mode = json_config.get('data_mode')
        self.random_seed = json_config.get('random_seed')
        self.iter_times = json_config.get('iter_times')
        self._check_config()

    def check_user_settings(self):
        """Validate ``iter_times``, then load and validate the extracted API json.

        Returns:
            dict: the content of ``extract_api_path``. Its first item is the
            forward record; for backward propagation, its second item is the
            backward record.

        Raises:
            ValueError: if ``iter_times`` is not positive, or the json content is
            empty, not a dict, has more than ``API_INFO`` entries, or lacks a
            dict-valued forward (and, for backward propagation, backward) record.
        """
        if self.iter_times <= 0:
            raise ValueError("iter_times should be an integer bigger than zero!")

        json_content = load_json(self.extract_api_path)

        if not json_content:
            raise ValueError('json file is empty!')
        if not isinstance(json_content, dict):
            raise ValueError('content of json file is not a dict!')
        if len(json_content) > API_INFO:
            raise ValueError('json file has more than one API, the API only contains forward and backward info')

        items = list(json_content.items())
        # The dict is non-empty here, so the forward record always exists.
        if not isinstance(items[0][1], dict):
            raise ValueError('Invalid forward API data in json_content!')

        if self.propagation == Const.BACKWARD:
            if len(items) < API_INFO:
                raise ValueError('Backward propagation requires contains forward and backward info!')
            if not isinstance(items[1][1], dict):
                raise ValueError('Invalid backward API data in json_content!')

        return json_content

    def _check_config(self):
        """Validate the raw settings and create the parent directory of ``extract_api_path``.

        Raises:
            ValueError: on an over-long ``api_name``, a missing ``extract_api_path``,
            an unknown ``propagation`` / ``data_mode``, or a non-int
            ``random_seed`` / ``iter_times``.
        """
        if self.dump_json_path:
            check_file_or_directory_path(self.dump_json_path)
        if self.api_name:
            check_op_str_pattern_valid(self.api_name)
            if len(self.api_name) > API_MAX_LENGTH:
                raise ValueError(f'API name {self.api_name} is too long!')
        if not self.extract_api_path:
            # Without this guard, os.path.dirname(None) fails with an opaque TypeError.
            raise ValueError('extract_api_path is required but was not provided!')
        make_dir(os.path.dirname(self.extract_api_path))
        if self.propagation and self.propagation not in PROPAGATION_LIST:
            raise ValueError(f'propagation is invalid, it should be one of {PROPAGATION_LIST}')
        if self.data_mode and self.data_mode not in DATAMODE_LIST:
            raise ValueError(f'data_mode is invalid, it should be one of {DATAMODE_LIST}')
        if not is_int(self.random_seed):
            raise ValueError('random_seed is invalid, it should be an int')
        if not is_int(self.iter_times):
            raise ValueError('iter_times is invalid, it should be an int')
|
|
149
|
+
|
|
150
|
+
class APIExtractor:
    """Copy the dump records of one API from a dump json into a standalone json file."""

    def __init__(self, api_name, dump_json_path, output_file):
        self.api_name = api_name
        self.dump_json_path = dump_json_path
        self.output_file = output_file
        self.data = None

    def extract_op(self):
        """Extract every ``data`` entry whose key starts with ``<api_name>.``.

        If the dump has a non-empty ``dump_data_dir``, the ``data_name`` fields of
        the extracted records are rewritten to paths under that directory. The
        result is saved to ``output_file``. If nothing matches, an error is
        logged and no file is written.
        """
        self.data = load_json(self.dump_json_path)
        # Raw f-string: "\." must reach the regex engine as an escaped dot
        # (a plain string literal makes it an invalid escape sequence).
        extract_key_pattern = re.compile(rf"^{re.escape(self.api_name)}\..+")
        real_data_path = self.data.get('dump_data_dir', '')
        new_data = {}
        for key, value in self.data.get('data', {}).items():
            if not extract_key_pattern.match(key):
                continue
            if real_data_path:
                value = self.load_real_data_path(value, real_data_path)
            new_data[key] = value

        if not new_data:
            logger.error(f"Error: The api '{self.api_name}' does not exist in the file.")
        else:
            save_json(self.output_file, new_data, indent=4)
            logger.info(
                f"The api '{self.api_name}' has been successfully extracted and saved in: {self.output_file}")

    def load_real_data_path(self, value, dump_data_dir):
        """Prefix ``dump_data_dir`` onto every ``data_name`` in one API record (in place).

        Positional/grad/output lists and the keyword-argument mapping are
        visited. ``None`` entries are skipped. Returns ``value`` itself.
        """
        parameters = [Const.INPUT_ARGS, Const.GRAD_INPUT, Const.INPUT, Const.OUTPUT, Const.GRAD_OUTPUT]
        for parameter in parameters:
            # `or []` tolerates a parameter stored as null in the dump.
            for v in value.get(parameter) or []:
                if v is not None:
                    self.update_data_name(v, dump_data_dir)
        # Keyword arguments are stored as a name -> record mapping; walk its values
        # so real-data kwargs tensors also point into the dump directory.
        kwargs_info = value.get(Const.INPUT_KWARGS)
        if isinstance(kwargs_info, dict):
            for v in kwargs_info.values():
                if v is not None:
                    self.update_data_name(v, dump_data_dir)
        return value

    def update_data_name(self, data, dump_data_dir):
        """Recursively join ``dump_data_dir`` onto the ``data_name`` of dict records.

        Lists/tuples are walked. Only dict records holding a ``data_name`` key
        are modified; any other leaf value is left untouched.
        """
        if isinstance(data, (list, tuple)):
            for item in data:
                self.update_data_name(item, dump_data_dir)
        elif isinstance(data, dict) and DATA_NAME in data:
            data[DATA_NAME] = os.path.join(dump_data_dir, data[DATA_NAME])
|
|
188
|
+
|
|
189
|
+
class OperatorScriptGenerator:
|
|
190
|
+
def __init__(self, common_config, args_info_forward, kwargs_info_forward, args_info_backward):
|
|
191
|
+
self.common_config = common_config
|
|
192
|
+
self.args_info_forward = args_info_forward
|
|
193
|
+
self.kwargs_info_forward = kwargs_info_forward
|
|
194
|
+
self.args_info_backward = args_info_backward
|
|
195
|
+
|
|
196
|
+
@staticmethod
|
|
197
|
+
def get_compare_standard(api_name):
|
|
198
|
+
api_standard_map = {
|
|
199
|
+
"binary_standard_api": "CompareStandard.BINARY_EQUALITY_STANDARD",
|
|
200
|
+
"absolute_standard_api": "CompareStandard.ABSOLUTE_THRESHOLD_STANDARD",
|
|
201
|
+
"ulp_standard_api": "CompareStandard.ULP_ERROR_STANDARD",
|
|
202
|
+
"thousandth_standard_api": "CompareStandard.THOUSANDTH_STANDARD"
|
|
203
|
+
}
|
|
204
|
+
for standard_api, standard_value in api_standard_map.items():
|
|
205
|
+
if api_name in globals()[standard_api]:
|
|
206
|
+
return standard_value
|
|
207
|
+
return "CompareStandard.BENCHMARK_STANDARD"
|
|
208
|
+
|
|
209
|
+
@staticmethod
|
|
210
|
+
def extract_detailed_api_segments(full_api_name):
|
|
211
|
+
"""
|
|
212
|
+
Function Description:
|
|
213
|
+
Extract the name of the API.
|
|
214
|
+
Parameter:
|
|
215
|
+
full_api_name_with_direction_status: Full name of the API. Example: torch.matmul.0.forward.output.0
|
|
216
|
+
Return:
|
|
217
|
+
api_name: Name of api. Example: matmul, mul, etc.
|
|
218
|
+
full_api_name: Full name of api. Example: torch.matmul.0
|
|
219
|
+
direction_status: Direction status of api. Example: forward, backward, etc.
|
|
220
|
+
"""
|
|
221
|
+
api_parts = full_api_name.split(Const.SEP)
|
|
222
|
+
api_parts_length = len(api_parts)
|
|
223
|
+
api_type, api_name, api_order = None, None, None
|
|
224
|
+
if api_parts_length == FOUR_SEGMENT:
|
|
225
|
+
api_type, api_name, api_order, _ = api_parts
|
|
226
|
+
elif api_parts_length == FIVE_SEGMENT:
|
|
227
|
+
api_type, prefix, api_name, api_order, _ = api_parts
|
|
228
|
+
api_name = Const.SEP.join([prefix, api_name])
|
|
229
|
+
return api_type, api_name, api_order
|
|
230
|
+
|
|
231
|
+
def get_settings(self, api_full_name):
|
|
232
|
+
'''
|
|
233
|
+
internal_settings contain all information needed for the operator program.
|
|
234
|
+
keys:
|
|
235
|
+
api_full_name: api_type.api_name.ordinal_number
|
|
236
|
+
api_type: type of API, one of torch.nn.functional, torch.Tensor or Torch
|
|
237
|
+
api_name: name of API
|
|
238
|
+
ordinal_number: how many times the same api has been called
|
|
239
|
+
direction_status: forward
|
|
240
|
+
random_seed: if mode is random_data, random seed is random_seed
|
|
241
|
+
iter_times: if mode is random_data, generate iter_times group of data; if mode is real_data, iter_times does not matter
|
|
242
|
+
args_element_assignment: code for args assignment
|
|
243
|
+
args_list_generator_device: code for generate args list on device
|
|
244
|
+
args_list_generator_bench: code for generate args list on bench
|
|
245
|
+
kwargs_value_assignment: code for kwargs assignment
|
|
246
|
+
kwargs_dict_generator_device: code for generate kwargs dict on device
|
|
247
|
+
kwargs_dict_generator_bench: code for generate kwargs dict on bench
|
|
248
|
+
'''
|
|
249
|
+
# Generate an internal setting dictionary based on user settings
|
|
250
|
+
# including API name, type, comparison standard, random seed, number of iterations and other information
|
|
251
|
+
internal_settings = {}
|
|
252
|
+
internal_settings["propagation"] = self.common_config.propagation
|
|
253
|
+
internal_settings["api_full_name"] = api_full_name
|
|
254
|
+
api_type, api_name, ordinal_number = self.extract_detailed_api_segments(api_full_name)
|
|
255
|
+
if api_type == "Functional":
|
|
256
|
+
internal_settings["api_type"] = "torch.nn.functional"
|
|
257
|
+
elif api_type == "Tensor":
|
|
258
|
+
internal_settings["api_type"] = "torch.Tensor"
|
|
259
|
+
else:
|
|
260
|
+
internal_settings["api_type"] = "torch"
|
|
261
|
+
internal_settings["api_name"] = api_name
|
|
262
|
+
internal_settings["compare_standard"] = self.get_compare_standard(api_name)
|
|
263
|
+
internal_settings["ordinal_number"] = ordinal_number
|
|
264
|
+
internal_settings["direction_status"] = self.common_config.propagation
|
|
265
|
+
internal_settings["random_seed"] = self.common_config.random_seed
|
|
266
|
+
if self.common_config.data_mode == "real_data":
|
|
267
|
+
internal_settings["iter_times"] = 1
|
|
268
|
+
else:
|
|
269
|
+
internal_settings["iter_times"] = self.common_config.iter_times
|
|
270
|
+
internal_settings["args_element_assignment"] = self.generate_args_element_assignment_code(self.args_info_forward)
|
|
271
|
+
internal_settings["args_list_generator_device"] = self.generate_args_list(self.args_info_forward, flag_device=True)
|
|
272
|
+
internal_settings["args_list_generator_bench"] = self.generate_args_list(self.args_info_forward, flag_device=False)
|
|
273
|
+
internal_settings["kwargs_value_assignment"] = self.generate_kwargs_value_assignment_code(self.kwargs_info_forward)
|
|
274
|
+
internal_settings["kwargs_dict_generator_device"] = self.generate_kwargs_dict(self.kwargs_info_forward, flag_device=True)
|
|
275
|
+
internal_settings["kwargs_dict_generator_bench"] = self.generate_kwargs_dict(self.kwargs_info_forward, flag_device=False)
|
|
276
|
+
if self.common_config.propagation == Const.BACKWARD:
|
|
277
|
+
internal_settings["args_element_assignment_backward"] = self.generate_args_element_assignment_code(
|
|
278
|
+
self.args_info_backward)
|
|
279
|
+
internal_settings["args_list_generator_device_backward"] = self.generate_args_list(self.args_info_backward, flag_device=True)
|
|
280
|
+
internal_settings["args_list_generator_bench_backward"] = self.generate_args_list(self.args_info_backward, flag_device=False)
|
|
281
|
+
else:
|
|
282
|
+
internal_settings["args_element_assignment_backward"] = ''
|
|
283
|
+
internal_settings["args_list_generator_device_backward"] = ''
|
|
284
|
+
internal_settings["args_list_generator_bench_backward"] = ''
|
|
285
|
+
|
|
286
|
+
return internal_settings
|
|
287
|
+
|
|
288
|
+
@recursion_depth_decorator("OpGenerator: OperatorScriptGenerator.recursive_args_element_assignment")
|
|
289
|
+
def recursive_args_element_assignment(self, args_info, name_number):
|
|
290
|
+
args_element_assignment = ""
|
|
291
|
+
for index, arg in enumerate(args_info):
|
|
292
|
+
if isinstance(arg, (list, tuple)):
|
|
293
|
+
new_args_element_assignment = self.recursive_args_element_assignment(arg, name_number + "_" + str(index))
|
|
294
|
+
args_element_assignment += new_args_element_assignment
|
|
295
|
+
else:
|
|
296
|
+
arg["parameter_name"] = "arg" + name_number + "_" + str(index)
|
|
297
|
+
args_element_assignment += " " + "arg_info" + name_number + "_" + str(index) + " = " + "{}".format(str(arg)) + MsgConst.SPECIAL_CHAR[0]
|
|
298
|
+
args_element_assignment += " " + "arg" + name_number + "_" + str(index) + " = " + "generate_data(arg_info" + name_number + "_" + str(index) + ")" + MsgConst.SPECIAL_CHAR[0]
|
|
299
|
+
return args_element_assignment
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def generate_args_element_assignment_code(self, args_info):
|
|
303
|
+
args_element_assignment = self.recursive_args_element_assignment(args_info, "")
|
|
304
|
+
return args_element_assignment
|
|
305
|
+
|
|
306
|
+
@recursion_depth_decorator("OpGenerator: OperatorScriptGenerator.recursive_args_list")
|
|
307
|
+
def recursive_args_list(self, args_info, flag_device=False, flag_bench=False):
|
|
308
|
+
args_list_generator = ""
|
|
309
|
+
for _, arg in enumerate(args_info):
|
|
310
|
+
if isinstance(arg, (list, tuple)):
|
|
311
|
+
(left_bracket, right_bracket) = ("[", "]") if isinstance(arg, list) else ("(", ")")
|
|
312
|
+
args_list_generator += left_bracket
|
|
313
|
+
new_args_list_generator = self.recursive_args_list(arg, flag_device=flag_device, flag_bench=flag_bench)
|
|
314
|
+
args_list_generator += new_args_list_generator
|
|
315
|
+
args_list_generator += right_bracket
|
|
316
|
+
else:
|
|
317
|
+
args_list_generator += arg.get("parameter_name")
|
|
318
|
+
if arg.get("type") in TENSOR_DATA_LIST:
|
|
319
|
+
if flag_device:
|
|
320
|
+
args_list_generator += ".to(device)"
|
|
321
|
+
if flag_bench:
|
|
322
|
+
args_list_generator += '.to(torch.device("cpu"))'
|
|
323
|
+
args_list_generator += ".to(RAISE_PRECISION.get(str(" + arg.get("parameter_name") + ".dtype), " + arg.get("parameter_name") + ".dtype))"
|
|
324
|
+
args_list_generator += Const.COMMA
|
|
325
|
+
return args_list_generator
|
|
326
|
+
|
|
327
|
+
def generate_args_list(self, args_info, flag_device):
|
|
328
|
+
if flag_device:
|
|
329
|
+
args_list_generator = self.recursive_args_list(args_info, flag_device=True)
|
|
330
|
+
else:
|
|
331
|
+
args_list_generator = self.recursive_args_list(args_info, flag_bench=True)
|
|
332
|
+
return args_list_generator
|
|
333
|
+
|
|
334
|
+
@recursion_depth_decorator("OpGenerator: OperatorScriptGenerator.recursive_kwargs_value_assignment")
|
|
335
|
+
def recursive_kwargs_value_assignment(self, info, key_name, name_number):
|
|
336
|
+
kwargs_value_assignment = ""
|
|
337
|
+
if isinstance(info, dict):
|
|
338
|
+
if info.get("type") == "torch.device" or info.get("type") == "torch.dtype":
|
|
339
|
+
kwargs_value_assignment += " " + "kwarg_" + key_name + name_number + " = " + info.get("value")
|
|
340
|
+
else:
|
|
341
|
+
kwargs_value_assignment += " " + "kwarg_info_" + key_name + name_number + " = " + "{}".format(str(info)) + MsgConst.SPECIAL_CHAR[0]
|
|
342
|
+
kwargs_value_assignment += " " + "kwarg_" + key_name + name_number + " = " + "generate_data(kwarg_info_" + key_name + name_number + ")" + MsgConst.SPECIAL_CHAR[0]
|
|
343
|
+
info["parameter_name"] = "kwarg_" + key_name + name_number
|
|
344
|
+
else:
|
|
345
|
+
for index, arg in enumerate(info):
|
|
346
|
+
new_kwargs_value_assignment = self.recursive_kwargs_value_assignment(arg, key_name, name_number + "_" + str(index))
|
|
347
|
+
kwargs_value_assignment += new_kwargs_value_assignment
|
|
348
|
+
return kwargs_value_assignment
|
|
349
|
+
|
|
350
|
+
def generate_kwargs_value_assignment_code(self, kwargs_info):
    """Concatenate the generated assignment statements for every keyword argument.

    Each ``(name, value)`` pair of ``kwargs_info`` is rendered, in dict iteration order,
    by ``recursive_kwargs_value_assignment`` with an empty name suffix.

    Returns:
        str: the concatenated statements ("" when ``kwargs_info`` is empty).
    """
    return "".join(
        self.recursive_kwargs_value_assignment(kwarg_value, kwarg_name, "")
        for kwarg_name, kwarg_value in kwargs_info.items()
    )
|
|
355
|
+
|
|
356
|
+
@recursion_depth_decorator("OpGenerator: OperatorScriptGenerator.recursive_kwargs_dict")
def recursive_kwargs_dict(self, info, flag_device=False, flag_bench=False):
    """Render the source expression for one keyword-argument value.

    Dict records render as their ``parameter_name``, which is expected to have been set
    earlier by ``recursive_kwargs_value_assignment``. When the record's ``type`` is in
    ``TENSOR_DATA_LIST``:
    - ``flag_device`` appends ``.to(device)``;
    - ``flag_bench`` appends a move to ``torch.device("cpu")`` plus a cast to
      ``RAISE_PRECISION.get(str(<var>.dtype), <var>.dtype)``.

    Any non-dict value is iterated instead. Lists render as ``[...]`` and everything else
    as ``(...)``, with ``Const.COMMA`` after every element.
    """
    if not isinstance(info, dict):
        opening, closing = ("[", "]") if isinstance(info, list) else ("(", ")")
        rendered_items = "".join(
            self.recursive_kwargs_dict(item, flag_device=flag_device, flag_bench=flag_bench) + Const.COMMA
            for item in info
        )
        return opening + rendered_items + closing

    var_name = info.get("parameter_name")
    pieces = [var_name]
    if info.get("type") in TENSOR_DATA_LIST:
        if flag_device:
            pieces.append(".to(device)")
        if flag_bench:
            pieces.append('.to(torch.device("cpu"))')
            pieces.append(".to(RAISE_PRECISION.get(str(" + var_name + ".dtype), " + var_name + ".dtype))")
    return "".join(pieces)
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def generate_kwargs_dict(self, kwargs_info, flag_device):
    """Render the ``"key": <expr>,`` entries of the kwargs dict literal for the generated script.

    No enclosing braces are emitted; presumably the template supplies them (confirm
    against operator_replication.template).

    Args:
        kwargs_info: mapping of keyword-argument name to its dump record.
        flag_device: truthy selects the device-side rendering (``flag_device=True``);
            falsy selects the CPU benchmark rendering (``flag_bench=True``).

    Returns:
        str: the concatenated entries ("" for an empty ``kwargs_info``).
    """
    # The rendering mode does not change per entry, so select it once.
    target_option = {"flag_device": True} if flag_device else {"flag_bench": True}
    kwargs_dict_generator = ""
    for key, value in kwargs_info.items():
        # A quoted key in a dict literal must be followed by ":". This was previously
        # taken from MonitorConst.VPP_SEP, a constant from an unrelated group, which made
        # the generated code silently depend on that value.
        kwargs_dict_generator += '"' + key + '"' + ":"
        kwargs_dict_generator += self.recursive_kwargs_dict(value, **target_option) + Const.COMMA
    return kwargs_dict_generator
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
def op_generator_parser(parser):
    """Register the op-generator command-line options on ``parser``.

    Both options are mandatory (``required=True``):
      -i/--config_input:    JSON config file, loaded by ``parse_json_config``.
      -o/--api_output_path: directory in which ``main`` writes the generated
                            ``<api_full_name>.py`` operator script.
    """
    # Help text previously marked -i as "<Optional>" despite required=True, and described
    # -o as the extracted api_name.json path although it is the script output directory.
    parser.add_argument("-i", "--config_input", dest="config_input", default='', type=str,
                        help="<Required> Path of config json file", required=True)
    parser.add_argument("-o", "--api_output_path", dest="api_output_path", type=str,
                        help="<Required> Directory where the generated operator script is saved.",
                        required=True)
|
|
395
|
+
|
|
396
|
+
def parse_json_config(json_file_path):
    """Load the op-generator JSON config and wrap it in ``CommonConfig``.

    A falsy ``json_file_path`` falls back to ``config.json`` in the parent of the
    directory that contains this module (symlinks resolved).
    """
    config_path = json_file_path or os.path.join(
        os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "config.json"
    )
    raw_config = load_json(config_path)
    return CommonConfig(raw_config)
|
|
403
|
+
|
|
404
|
+
def main():
    """Command-line entry point: render a standalone operator replication script.

    Steps:
      1. Parse ``-i/--config_input`` and ``-o/--api_output_path``.
      2. When ``dump_json_path`` is configured, extract the requested API from the dump
         into ``extract_api_path``.
      3. Validate the paths, load the API info and build the template settings for the
         configured propagation (forward or backward).
      4. Fill ``operator_replication.template`` and write the result to
         ``<api_output_path>/<api_full_name>.py``.

    Raises:
        ValueError: backward propagation was requested but the backward record holds
            neither ``Const.GRAD_INPUT`` nor ``Const.INPUT``.
    """
    parser = argparse.ArgumentParser()
    op_generator_parser(parser)
    cmd_args = parser.parse_args()

    common_config = parse_json_config(cmd_args.config_input)

    if common_config.dump_json_path:
        api_extract = APIExtractor(common_config.api_name, common_config.dump_json_path, common_config.extract_api_path)
        api_extract.extract_op()
    check_file_or_directory_path(common_config.extract_api_path)
    check_file_or_directory_path(cmd_args.api_output_path, isdir=True)
    json_content = common_config.check_user_settings()
    api_info = APIInfo.from_json(json_content, common_config.propagation)

    # Both modes use the forward inputs.
    api_info_dict_forward = api_info.api_info_dict
    args_info_forward = api_info_dict_forward.get(Const.INPUT_ARGS)
    kwargs_info_forward = api_info_dict_forward.get(Const.INPUT_KWARGS)

    if common_config.propagation == Const.BACKWARD:
        api_full_name_backward = api_info.backward_info.api_full_name
        api_info_dict_backward = api_info.backward_info.api_info_dict
        if Const.GRAD_INPUT in api_info_dict_backward:
            args_info_backward = api_info_dict_backward.get(Const.GRAD_INPUT)
        elif Const.INPUT in api_info_dict_backward:
            args_info_backward = api_info_dict_backward.get(Const.INPUT)
        else:
            # Previously this case fell through with args_info_backward unbound
            # (UnboundLocalError); report the actual problem instead.
            error_msg = (f"Backward info of {api_full_name_backward} contains neither "
                         f"{Const.GRAD_INPUT} nor {Const.INPUT}.")
            logger.error(error_msg)
            raise ValueError(error_msg)
        op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward, args_info_backward)
        internal_settings = op_generate.get_settings(api_full_name_backward)
    else:
        op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward, None)
        internal_settings = op_generate.get_settings(api_info.api_full_name)

    template_path = os.path.join(os.path.dirname(__file__), "operator_replication.template")
    operator_script_path = os.path.join(cmd_args.api_output_path, "{0}.py".format(internal_settings.get("api_full_name")))

    try:
        with FileOpen(template_path, 'r') as ftemp, FileOpen(operator_script_path, 'w') as fout:
            code_template = ftemp.read()
            fout.write(code_template.format(**internal_settings))
    except OSError:
        logger.error(f"Failed to open file. Please check file {template_path} or {operator_script_path}.")
        # Do not fall through to the success message below.
        return

    logger.info(f"Generate operator script successfully and the name is {operator_script_path}.")


if __name__ == "__main__":
    main()
|