mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -0,0 +1,460 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
18
|
+
# Standard library
|
|
19
|
+
import argparse
|
|
20
|
+
import json
|
|
21
|
+
import os
|
|
22
|
+
import re
|
|
23
|
+
import string
|
|
24
|
+
|
|
25
|
+
# Project-local modules
|
|
26
|
+
from msprobe.core.common.file_utils import (
|
|
27
|
+
FileOpen,
|
|
28
|
+
load_json,
|
|
29
|
+
save_json,
|
|
30
|
+
make_dir,
|
|
31
|
+
change_mode,
|
|
32
|
+
)
|
|
33
|
+
from msprobe.core.common.utils import (
|
|
34
|
+
check_file_or_directory_path,
|
|
35
|
+
check_op_str_pattern_valid,
|
|
36
|
+
is_int,
|
|
37
|
+
)
|
|
38
|
+
from msprobe.core.common.const import Const, MonitorConst, MsgConst, FileCheckConst
|
|
39
|
+
from msprobe.core.common.log import logger
|
|
40
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
41
|
+
|
|
42
|
+
# API type prefixes (first dot-separated segment of an API full name) accepted
# by APIInfo.is_supported_type.
OPERATOR_TYPE = ("Functional", "Tensor", "Torch", "Mint")

# Maximum number of API entries allowed in the extracted json
# (one forward record plus one backward record).
API_INFO = 2
# Number of dot-separated segments in an API full name, without / with an
# extra prefix segment (see OperatorScriptGenerator.extract_detailed_api_segments).
FOUR_SEGMENT = 4
FIVE_SEGMENT = 5
# Key inside a dumped tensor record whose value is joined onto the dump data dir.
DATA_NAME = "data_name"
# Upper bound on the length of the configured api_name.
API_MAX_LENGTH = 30
# Accepted values of the "propagation" config entry.
PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
# Accepted values of the "data_mode" config entry.
DATAMODE_LIST = ["random_data", "real_data"]
# Upper bound (inclusive) of the "iter_times" config entry.
ITER_MAX_TIMES = 1000
# Metadata keys stored next to the API records in the extracted json.
FRAMEWORK = 'framework'
REAL_DATA_PATH = 'real_data_path'
# Keys ignored when counting API records (name spelling kept as-is because
# other code in this file refers to it).
EXCLUED = {FRAMEWORK, REAL_DATA_PATH}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class APIInfo:
    """Dump record of one API call: its forward info and, optionally, its backward info.

    Attributes:
        api_full_name: key of the record in the json file.
        api_info_dict: value of the record in the json file.
        backward_info: APIInfo of the matching backward record, or None.
    """

    def __init__(self, api_full_name, api_info_dict, backward_info=None):
        self.api_full_name = api_full_name
        self.api_info_dict = api_info_dict
        self.backward_info = backward_info

    @property
    def api_type(self):
        """Return the first ``Const.SEP``-separated segment of the API full name."""
        return self.api_full_name.split(Const.SEP, 1)[0]

    @classmethod
    def from_json(cls, json_content, propagation):
        """Build an APIInfo from the content of an extracted-API json file.

        The metadata keys listed in ``EXCLUED`` are skipped, so the first
        remaining entry is taken as the forward record and the second one as
        the backward record regardless of where the metadata sits in the dict.

        Args:
            json_content: dict loaded from the extracted-API json file.
            propagation: ``Const.FORWARD`` or ``Const.BACKWARD``.

        Returns:
            The forward APIInfo; ``backward_info`` is set when propagation is backward.

        Raises:
            ValueError: if no API record is present, if backward propagation is
                requested without a second record, or if the API type is not
                one of ``OPERATOR_TYPE``.
        """
        api_items = [(name, info) for name, info in json_content.items() if name not in EXCLUED]
        if not api_items:
            raise ValueError("json content does not contain any API info!")

        forward_name, forward_dict = api_items[0]
        forward_info = cls(api_full_name=forward_name, api_info_dict=forward_dict)

        if propagation == Const.BACKWARD:
            if len(api_items) < API_INFO:
                raise ValueError("Backward propagation requires both forward and backward API info!")
            backward_name, backward_dict = api_items[1]
            forward_info.backward_info = cls(api_full_name=backward_name, api_info_dict=backward_dict)

        if not forward_info.is_supported_type():
            raise ValueError(f"type {forward_info.api_type} of API is not supported!")

        return forward_info

    def is_supported_type(self):
        """Return True when ``api_type`` is one of ``OPERATOR_TYPE``."""
        return self.api_type in OPERATOR_TYPE
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class CommonConfig:
    """User settings for operator-script generation, read from a config dict.

    The constructor copies the known keys from ``json_config`` and immediately
    validates them with ``_check_config``.
    """

    def __init__(self, json_config):
        self.dump_json_path = json_config.get('dump_json_path')
        self.api_name = json_config.get('api_name')
        self.extract_api_path = json_config.get('extract_api_path')
        self.propagation = json_config.get('propagation')
        self.data_mode = json_config.get('data_mode')
        self.random_seed = json_config.get('random_seed')
        self.iter_times = json_config.get('iter_times')
        self._check_config()

    def check_user_settings(self):
        """Validate ``iter_times`` and the extracted-API json file, then return its content.

        API records are the entries whose key is not in ``EXCLUED``; both the
        counting and the per-record checks operate on those entries only, so
        metadata keys are never mistaken for forward/backward records.

        Returns:
            The dict loaded from ``extract_api_path`` (metadata keys included).

        Raises:
            ValueError: if ``iter_times`` is out of range or the json content
                is empty, not a dict, holds more than one API, or lacks valid
                forward/backward data for the configured propagation.
        """
        iter_t = self.iter_times
        if iter_t <= 0 or iter_t > ITER_MAX_TIMES:
            raise ValueError(f"iter_times should be range from 1 to {ITER_MAX_TIMES}.")

        propagation = self.propagation
        json_content = load_json(self.extract_api_path)

        # ensure the content is not empty
        if not json_content:
            raise ValueError('json file is empty!')

        # ensure json_content is of type dict
        if not isinstance(json_content, dict):
            raise ValueError('content of json file is not a dict!')

        # only API records count towards the limits below
        filtered = {k: v for k, v in json_content.items() if k not in EXCLUED}
        if not filtered:
            raise ValueError('json file is empty!')

        if len(filtered) > API_INFO:
            raise ValueError('json file has more than one API, the API only contains forward and backward info')

        # two records that are both forward records mean two different APIs
        is_forward_phase = propagation == Const.FORWARD
        is_exact_api_count = len(filtered) == API_INFO
        all_keys_forward = all(k.endswith('forward') for k in filtered)
        if is_forward_phase and is_exact_api_count and all_keys_forward:
            raise ValueError(
                "json file has more than one API, the API only contains forward info。"
            )

        api_records = list(filtered.values())

        # the first API record is the forward one and must be a non-empty dict
        forward_data = api_records[0]
        if not isinstance(forward_data, dict) or not forward_data:
            raise ValueError('Invalid forward API data in json_content!')

        if propagation == Const.BACKWARD:
            # backward propagation needs both the forward and the backward record
            if len(api_records) < API_INFO:
                raise ValueError('Backward propagation requires contains forward and backward info!')
            backward_data = api_records[1]
            if not isinstance(backward_data, dict) or not backward_data:
                raise ValueError('Invalid backward API data in json_content!')

        return json_content

    def _check_config(self):
        """Validate the raw config values and create the output directory.

        Raises:
            ValueError: if ``extract_api_path`` is missing, ``api_name`` is too
                long, ``propagation`` / ``data_mode`` is not an accepted value,
                or ``random_seed`` / ``iter_times`` is not an int.
        """
        if self.dump_json_path:
            check_file_or_directory_path(self.dump_json_path)
        if self.api_name:
            check_op_str_pattern_valid(self.api_name)
            if len(self.api_name) > API_MAX_LENGTH:
                raise ValueError(f'API name {self.api_name} is too long!')
        # fail with a clear message instead of a TypeError from os.path.dirname(None)
        if not self.extract_api_path:
            raise ValueError('extract_api_path can not be empty, please check.')
        make_dir(os.path.dirname(self.extract_api_path))
        if self.propagation and self.propagation not in PROPAGATION_LIST:
            raise ValueError(f'propagation is invalid, it should be one of {PROPAGATION_LIST}')
        if self.data_mode and self.data_mode not in DATAMODE_LIST:
            raise ValueError(f'data_mode is invalid, it should be one of {DATAMODE_LIST}')
        if not is_int(self.random_seed):
            raise ValueError('random_seed is invalid, it should be an int')
        if not is_int(self.iter_times):
            raise ValueError('iter_times is invalid, it should be an int')
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class APIExtractor:
    """Extract the records of one API from a dump json file into a smaller json file.

    Attributes:
        api_name: API name whose records are extracted.
        dump_json_path: path of the dump json file to read.
        output_file: path of the json file to write.
        data: content of the dump json file (set by ``extract_op``).
        framework: value of the ``framework`` key of the dump (set by ``extract_op``).
        real_data_path: value of the ``dump_data_dir`` key of the dump (set by ``extract_op``).
    """

    def __init__(self, api_name, dump_json_path, output_file):
        self.api_name = api_name
        self.dump_json_path = dump_json_path
        self.output_file = output_file
        self.data = None
        self.framework = None
        self.real_data_path = None

    def extract_op(self):
        """Load the dump, collect the records of ``api_name`` and save them.

        A record matches when its key starts with ``"<api_name>."`` followed by
        at least one more character. When nothing matches, a warning is logged
        and no file is written; otherwise the records plus the
        ``real_data_path`` / ``framework`` metadata are saved to ``output_file``.
        """
        self.data = load_json(self.dump_json_path)
        self.framework = self.data.get(FRAMEWORK, None)
        self.real_data_path = self.data.get('dump_data_dir', '')

        # raw f-string so that "\." reaches the regex engine as an escaped dot
        extract_key_pattern = re.compile(rf"^{re.escape(self.api_name)}\..+")

        new_data = {}
        for key, value in (self.data.get('data') or {}).items():
            if not extract_key_pattern.match(key):
                continue
            if self.real_data_path:
                value = self.load_real_data_path(value, self.real_data_path)
            new_data[key] = value

        # Check for matches before adding metadata; otherwise the metadata keys
        # make new_data non-empty and the warning can never be reached.
        if not new_data:
            logger.warning(f"Warning: The api '{self.api_name}' does not exist in the file.")
            return

        if self.real_data_path is not None:
            new_data[REAL_DATA_PATH] = self.real_data_path
        if self.framework is not None:
            new_data[FRAMEWORK] = self.framework

        save_json(self.output_file, new_data, indent=4)
        logger.info(
            f"The api '{self.api_name}' has been successfully extracted and saved in: {self.output_file}")

    def load_real_data_path(self, value, dump_data_dir):
        """Prefix every ``data_name`` found under the known argument keys with ``dump_data_dir``.

        Args:
            value: record of one API (modified in place).
            dump_data_dir: directory joined in front of each ``data_name``.

        Returns:
            The same ``value`` object.
        """
        # NOTE(review): Const.INPUT_KWARGS is not in this list, so data_name
        # entries under kwargs are left untouched — confirm this is intended.
        parameters = [Const.INPUT_ARGS, Const.GRAD_INPUT, Const.INPUT, Const.OUTPUT, Const.GRAD_OUTPUT]
        for parameter in parameters:
            # a key that is present with a null value is treated like a missing key
            for v in value.get(parameter) or []:
                if v is not None:
                    self.update_data_name(v, dump_data_dir)
        return value

    @recursion_depth_decorator("OpGenerator: APIExtractor.update_data_name")
    def update_data_name(self, data, dump_data_dir):
        """Recursively join ``dump_data_dir`` onto the ``data_name`` of each dict in ``data``.

        Lists are walked element by element; dicts holding ``data_name`` are
        updated in place; any other value (None, numbers, strings) is ignored.
        """
        if isinstance(data, list):
            for item in data:
                self.update_data_name(item, dump_data_dir)
        elif isinstance(data, dict) and DATA_NAME in data:
            data[DATA_NAME] = os.path.join(dump_data_dir, data[DATA_NAME])
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class OperatorScriptGenerator:
    """Collect the settings used to fill in the operator replication template.

    Attributes:
        common_config: validated user configuration.
        args_info_forward: forward positional-argument info of the API.
        kwargs_info_forward: forward keyword-argument info of the API.
        args_info_backward: backward gradient info of the API, or None.
    """

    def __init__(self, common_config, args_info_forward, kwargs_info_forward, args_info_backward):
        self.common_config = common_config
        self.args_info_forward = args_info_forward
        self.kwargs_info_forward = kwargs_info_forward
        self.args_info_backward = args_info_backward

    @staticmethod
    def extract_detailed_api_segments(full_api_name):
        """Split an API full name into its type, name and call order.

        Two layouts are recognised after splitting on ``Const.SEP``:
        four segments ``type, name, order, <last>`` and five segments
        ``type, prefix, name, order, <last>`` (the name is then
        ``prefix + Const.SEP + name``). The last segment is discarded.

        Args:
            full_api_name: full name of the API as used in the json file.

        Returns:
            Tuple ``(api_type, api_name, api_order)``; all three are None when
            the name has neither four nor five segments.
        """
        api_parts = full_api_name.split(Const.SEP)
        api_type, api_name, api_order = None, None, None
        if len(api_parts) == FOUR_SEGMENT:
            api_type, api_name, api_order, _ = api_parts
        elif len(api_parts) == FIVE_SEGMENT:
            api_type, prefix, api_name, api_order, _ = api_parts
            api_name = Const.SEP.join([prefix, api_name])
        return api_type, api_name, api_order

    @staticmethod
    def _collect_parameter_names(info):
        """Return the ``parameter_name`` of every dict in a (possibly nested) structure.

        Dicts contribute their ``parameter_name``; any other non-None value is
        iterated and its elements are visited in order; None is skipped.
        """
        names = []

        def collect(node):
            if node is None:
                return
            if isinstance(node, dict):
                names.append(node["parameter_name"])
            else:
                for sub in node:
                    collect(sub)

        collect(info)
        return names

    @staticmethod
    def _tuple_source(names):
        """Render names as the source text of a tuple.

        A single name needs a trailing comma: ``(x)`` is just ``x``, so the
        generated loop would iterate over the object itself.
        """
        joined = ", ".join(names)
        if len(names) == 1:
            joined += ","
        return "(" + joined + ")"

    # NOTE(review): leading whitespace inside the generated-code literals below
    # is reproduced as seen in the reviewed source — confirm it matches the
    # indentation the template expects.
    @staticmethod
    def generate_forward_inputs_code(args_info):
        """Return source code that builds ``forward_inputs`` from the collected parameter names."""
        names = OperatorScriptGenerator._collect_parameter_names(args_info)
        return (
            " forward_inputs = [\n"
            " ComputeElement(parameter=info)\n"
            " for info in " + OperatorScriptGenerator._tuple_source(names) + "\n"
            " ]\n"
        )

    @staticmethod
    def generate_kwargs_compute_element_dict_code():
        """Return source code that builds ``kwargs_compute_element_dict`` from ``kwargs_device``."""
        return (
            " # ---- 构造 kwargs 对应的 ComputeElement 字典 ----\n"
            " kwargs_compute_element_dict = {\n"
            " key_str: ComputeElement(compute_element_info=compute_element_info)\n"
            " for key_str, compute_element_info in kwargs_device.items()\n"
            " }\n"
        )

    @staticmethod
    def generate_gradient_inputs_code(args_info_backward):
        """Return source code that builds ``gradient_inputs`` from the collected parameter names."""
        names = OperatorScriptGenerator._collect_parameter_names(args_info_backward)
        return (
            " # —— 构造反向梯度 ComputeElement 列表 —— #\n"
            " gradient_inputs = [\n"
            " ComputeElement(parameter=info)\n"
            " for info in " + OperatorScriptGenerator._tuple_source(names) + "\n"
            " ]\n"
        )

    def get_settings(self, api_full_name):
        """Build the dict of values substituted into the operator script template.

        Keys written:
            propagation / direction_status: configured propagation.
            api_full_name: the name passed in.
            api_type: "torch.nn.functional" for "Functional", "torch.Tensor"
                for "Tensor", "torch" for every other type.
            api_name, ordinal_number: from ``extract_detailed_api_segments``.
            random_seed, data_mode: copied from the configuration.
            iter_times: 1 when data_mode is "real_data", else the configured value.
            args_info_forward, kwargs_info_forward, args_info_backward:
                the info objects given to the constructor.

        Args:
            api_full_name: full name of the API the script is generated for.

        Returns:
            The settings dict.
        """
        config = self.common_config
        api_type, api_name, ordinal_number = self.extract_detailed_api_segments(api_full_name)
        type_mapping = {"Functional": "torch.nn.functional", "Tensor": "torch.Tensor"}

        internal_settings = {
            "propagation": config.propagation,
            "api_full_name": api_full_name,
            "api_type": type_mapping.get(api_type, "torch"),
            "api_name": api_name,
            "ordinal_number": ordinal_number,
            "direction_status": config.propagation,
            "random_seed": config.random_seed,
            "data_mode": config.data_mode,
            # real data is replayed once; the iteration count only applies to random data
            "iter_times": 1 if config.data_mode == "real_data" else config.iter_times,
            "args_info_forward": self.args_info_forward,
            "kwargs_info_forward": self.kwargs_info_forward,
            "args_info_backward": self.args_info_backward,
        }
        return internal_settings
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def _op_generator_parser(parser):
    """Register the command-line options of the operator script generator.

    Args:
        parser: argparse parser (or sub-parser) that receives the options.

    Options added:
        -i / --config_input: path of the config json file (required).
        -o / --api_output_path: directory the generated operator script is
            written into (required); the caller validates it as a directory.
    """
    parser.add_argument("-i", "--config_input", dest="config_input", type=str,
                        help="<Required> Path of config json file", required=True)
    parser.add_argument("-o", "--api_output_path", dest="api_output_path", type=str,
                        help="<Required> Directory in which the generated operator script is saved.",
                        required=True)
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def parse_json_config(json_file_path):
    """Load the config json file and wrap it in a validated CommonConfig.

    Args:
        json_file_path: path of the config json file.

    Returns:
        The CommonConfig built from the file content.

    Raises:
        ValueError: if ``json_file_path`` is empty or None (ValueError is a
            subclass of Exception, so existing ``except Exception`` handlers
            keep working), or if CommonConfig rejects the content.
    """
    if not json_file_path:
        raise ValueError("config_input path can not be empty, please check.")
    return CommonConfig(load_json(json_file_path))
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
class _RobustFormatter(string.Formatter):
    """Formatter that leaves placeholders it cannot resolve in the output.

    Any field whose lookup, conversion or formatting fails is written back as
    ``{name}``, ``{name!c}`` or ``{name:spec}`` instead of raising.
    """

    def vformat(self, format_string, args, kwargs):
        result = []
        # parse() yields the literal text in front of each replacement field
        for literal, field_name, format_spec, conversion in self.parse(format_string):
            result.append(literal)
            if field_name is None:
                continue
            try:
                obj, _ = self.get_field(field_name, args, kwargs)
                if conversion:
                    obj = self.convert_field(obj, conversion)
                result.append(self.format_field(obj, format_spec))
            except Exception:
                # Deliberately broad: whatever goes wrong with this field
                # (missing key, bad spec, ...), the placeholder is kept verbatim.
                placeholder = '{' + field_name
                if conversion:
                    placeholder += '!' + conversion
                if format_spec:
                    placeholder += ':' + format_spec
                placeholder += '}'
                result.append(placeholder)
        return ''.join(result)


def _run_operator_generate_commond(cmd_args):
    """Generate the operator replication script described by the command-line arguments.

    Steps: parse the config file; when ``dump_json_path`` is configured,
    extract the API records from the dump; validate the extracted json and the
    output directory; build the template settings; render
    ``operator_replication.template`` into ``<api_output_path>/<api_full_name>.py``
    and set the file mode.

    Args:
        cmd_args: parsed arguments providing ``config_input`` and ``api_output_path``.
    """
    common_config = parse_json_config(cmd_args.config_input)

    # Defaults keep both names bound when no dump extraction takes place.
    framework = None
    real_data_path = None
    extracted_from_dump = bool(common_config.dump_json_path)
    if extracted_from_dump:
        api_extract = APIExtractor(common_config.api_name, common_config.dump_json_path,
                                   common_config.extract_api_path)
        api_extract.extract_op()
        framework = api_extract.framework
        real_data_path = api_extract.real_data_path

    check_file_or_directory_path(common_config.extract_api_path)
    check_file_or_directory_path(cmd_args.api_output_path, isdir=True)
    json_content = common_config.check_user_settings()

    if not extracted_from_dump:
        # A previously extracted file carries the same metadata under these keys.
        framework = json_content.get(FRAMEWORK)
        real_data_path = json_content.get(REAL_DATA_PATH)

    api_info = APIInfo.from_json(json_content, common_config.propagation)

    api_info_dict_forward = api_info.api_info_dict
    args_info_forward = api_info_dict_forward.get(Const.INPUT_ARGS)
    kwargs_info_forward = api_info_dict_forward.get(Const.INPUT_KWARGS)

    if common_config.propagation == Const.BACKWARD:
        script_api_name = api_info.backward_info.api_full_name
        api_info_dict_backward = api_info.backward_info.api_info_dict
        # prefer GRAD_INPUT, fall back to INPUT; None when neither key exists
        if Const.GRAD_INPUT in api_info_dict_backward:
            args_info_backward = api_info_dict_backward.get(Const.GRAD_INPUT)
        else:
            args_info_backward = api_info_dict_backward.get(Const.INPUT)
    else:
        script_api_name = api_info.api_full_name
        args_info_backward = None

    op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward,
                                          args_info_backward)
    internal_settings = op_generate.get_settings(script_api_name)
    internal_settings[FRAMEWORK] = framework
    internal_settings[REAL_DATA_PATH] = real_data_path

    template_path = os.path.join(os.path.dirname(__file__), "operator_replication.template")
    operator_script_path = os.path.join(cmd_args.api_output_path,
                                        "{0}.py".format(internal_settings.get("api_full_name")))

    fmt = _RobustFormatter()
    with FileOpen(template_path, 'r') as ftemp, FileOpen(operator_script_path, 'w') as fout:
        code_template = ftemp.read()
        # fmt.format (not format_map) so that _RobustFormatter.vformat handles every field
        fout.write(fmt.format(code_template, **internal_settings))

    change_mode(operator_script_path, FileCheckConst.DATA_FILE_AUTHORITY)

    logger.info(f"Generate operator script successfully and the name is {operator_script_path}.")
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
if __name__ == "__main__":
    # Script entry point: register the CLI options, parse argv and run the generator.
    arg_parser = argparse.ArgumentParser()
    _op_generator_parser(arg_parser)
    _run_operator_generate_commond(arg_parser.parse_args())
|