mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -12,13 +12,16 @@
|
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
|
+
|
|
15
16
|
import re
|
|
16
|
-
|
|
17
|
-
from msprobe.core.compare.acc_compare import
|
|
17
|
+
|
|
18
|
+
from msprobe.core.compare.acc_compare import ModeConfig
|
|
19
|
+
from msprobe.core.compare.multiprocessing_compute import CompareRealData
|
|
20
|
+
from msprobe.core.compare.utils import read_op, merge_tensor, get_accuracy, make_result_table
|
|
18
21
|
from msprobe.core.common.utils import set_dump_path, get_dump_mode
|
|
19
22
|
from msprobe.visualization.utils import GraphConst
|
|
20
23
|
from msprobe.core.common.const import Const
|
|
21
|
-
|
|
24
|
+
|
|
22
25
|
|
|
23
26
|
# 用于将节点名字解析成对应的NodeOp的规则
|
|
24
27
|
op_patterns = [
|
|
@@ -54,13 +57,11 @@ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
|
|
|
54
57
|
mode_config = ModeConfig(stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.ALL)
|
|
55
58
|
|
|
56
59
|
if framework == Const.PT_FRAMEWORK:
|
|
57
|
-
from msprobe.pytorch.compare.pt_compare import
|
|
58
|
-
return
|
|
60
|
+
from msprobe.pytorch.compare.pt_compare import read_real_data
|
|
61
|
+
return CompareRealData(read_real_data, mode_config, is_cross_frame).do_multi_process(dump_path_param, csv_path)
|
|
59
62
|
else:
|
|
60
|
-
from msprobe.mindspore.compare.ms_compare import
|
|
61
|
-
|
|
62
|
-
ms_comparator.cross_frame = is_cross_frame
|
|
63
|
-
return ms_comparator.do_multi_process(dump_path_param, csv_path)
|
|
63
|
+
from msprobe.mindspore.compare.ms_compare import read_real_data
|
|
64
|
+
return CompareRealData(read_real_data, mode_config, is_cross_frame).do_multi_process(dump_path_param, csv_path)
|
|
64
65
|
|
|
65
66
|
|
|
66
67
|
def get_input_output(node_data, node_id):
|
|
@@ -120,11 +121,13 @@ def compare_data_fuzzy(data_dict_list1, data_dict_list2):
|
|
|
120
121
|
return True
|
|
121
122
|
|
|
122
123
|
|
|
123
|
-
def format_node_data(data_dict, node_id=None):
|
|
124
|
+
def format_node_data(data_dict, node_id=None, compare_mode=None):
|
|
124
125
|
"""
|
|
125
126
|
删除节点数据中不需要展示的字段
|
|
126
127
|
"""
|
|
127
128
|
del_list = ['requires_grad', 'full_op_name']
|
|
129
|
+
if GraphConst.MD5_COMPARE != compare_mode:
|
|
130
|
+
del_list.append(Const.MD5)
|
|
128
131
|
if node_id and GraphConst.BATCH_P2P in node_id:
|
|
129
132
|
del_list.extend(['op', 'peer', 'tag', 'group_id'])
|
|
130
133
|
for _, value in data_dict.items():
|
|
@@ -172,7 +175,7 @@ def _format_decimal_string(s):
|
|
|
172
175
|
"""
|
|
173
176
|
使用正则表达式匹配包含数字、小数点和可选的百分号的字符串
|
|
174
177
|
"""
|
|
175
|
-
pattern = re.compile(r'
|
|
178
|
+
pattern = re.compile(r'^\d{1,20}\.\d{1,20}%?$')
|
|
176
179
|
matches = pattern.findall(s)
|
|
177
180
|
for match in matches:
|
|
178
181
|
is_percent = match.endswith('%')
|
|
@@ -227,3 +230,12 @@ def _format_data(data_dict):
|
|
|
227
230
|
if all_null:
|
|
228
231
|
data_dict.clear()
|
|
229
232
|
data_dict[GraphConst.VALUE] = GraphConst.NULL
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def get_csv_df(stack_mode, csv_data, compare_mode):
|
|
236
|
+
"""
|
|
237
|
+
调用acc接口写入csv
|
|
238
|
+
"""
|
|
239
|
+
|
|
240
|
+
dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
|
|
241
|
+
return make_result_table(csv_data, dump_mode, stack_mode)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024, Huawei Technologies Co., Ltd.
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -14,23 +14,27 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import re
|
|
17
|
-
from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data
|
|
18
|
-
from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file
|
|
17
|
+
from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data, get_csv_df
|
|
18
|
+
from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file
|
|
19
19
|
from msprobe.visualization.graph.graph import Graph, NodeOp
|
|
20
|
-
from msprobe.visualization.graph.node_colors import NodeColors
|
|
21
20
|
from msprobe.visualization.compare.mode_adapter import ModeAdapter
|
|
22
21
|
from msprobe.core.common.const import Const
|
|
22
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class GraphComparator:
|
|
26
|
-
|
|
26
|
+
MAX_DEPTH = 1000
|
|
27
|
+
|
|
28
|
+
def __init__(self, graphs, dump_path_param, args, is_cross_framework, mapping_dict=None):
|
|
27
29
|
self.graph_n = graphs[0]
|
|
28
30
|
self.graph_b = graphs[1]
|
|
29
31
|
self._parse_param(dump_path_param, args.output_path)
|
|
30
32
|
self.framework = args.framework
|
|
33
|
+
self.layer_mapping = args.layer_mapping
|
|
31
34
|
self.mapping_dict = mapping_dict
|
|
32
35
|
self.fuzzy_match = args.fuzzy_match
|
|
33
36
|
self.pattern = re.compile(r'\.\d+\.')
|
|
37
|
+
self.is_cross_framework = is_cross_framework
|
|
34
38
|
|
|
35
39
|
def compare(self):
|
|
36
40
|
"""
|
|
@@ -41,7 +45,7 @@ class GraphComparator:
|
|
|
41
45
|
else:
|
|
42
46
|
self._compare_nodes(self.graph_n.root)
|
|
43
47
|
self._postcompare()
|
|
44
|
-
|
|
48
|
+
|
|
45
49
|
def add_compare_result_to_node(self, node, compare_result_list):
|
|
46
50
|
"""
|
|
47
51
|
将比对结果添加到节点的输入输出数据中
|
|
@@ -66,7 +70,58 @@ class GraphComparator:
|
|
|
66
70
|
self.ma.parse_result(node, [compare_in_dict, compare_out_dict]))
|
|
67
71
|
node.data[GraphConst.JSON_INDEX_KEY] = precision_index
|
|
68
72
|
node.data.update(other_dict)
|
|
69
|
-
|
|
73
|
+
|
|
74
|
+
def _compare_nodes(self, node_root):
|
|
75
|
+
"""
|
|
76
|
+
遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比
|
|
77
|
+
这里采用先序遍历,好处在于当这个节点被比较时,他的先序已经被匹配,这可以为后续的模糊匹配提供重要信息
|
|
78
|
+
"""
|
|
79
|
+
def compare_single_node(node_n):
|
|
80
|
+
if self.layer_mapping:
|
|
81
|
+
node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, self.mapping_dict)
|
|
82
|
+
if node_b:
|
|
83
|
+
ancestors_n.append(node_n.id)
|
|
84
|
+
ancestors_b.append(node_b.id)
|
|
85
|
+
node_n.matched_node_link = ancestors_b
|
|
86
|
+
node_b.matched_node_link = ancestors_n
|
|
87
|
+
else:
|
|
88
|
+
node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b)
|
|
89
|
+
if node_b:
|
|
90
|
+
ancestors.append(node_b.id)
|
|
91
|
+
node_n.add_link(node_b, ancestors)
|
|
92
|
+
if node_b:
|
|
93
|
+
# 真实数据比对只会得到基本信息,并没有精度指标,需要调用多进程对比接口
|
|
94
|
+
self._get_and_add_result(node_n, node_b)
|
|
95
|
+
node_list.extend(node_n.subnodes)
|
|
96
|
+
|
|
97
|
+
node_list = [node_root]
|
|
98
|
+
while node_list:
|
|
99
|
+
compare_single_node(node_list.pop(0))
|
|
100
|
+
|
|
101
|
+
def _compare_nodes_fuzzy(self, node_root):
|
|
102
|
+
def compare_single_nodes_fuzzy(node_n):
|
|
103
|
+
if node_n.op != NodeOp.function_api:
|
|
104
|
+
# 模块经过模糊匹配
|
|
105
|
+
node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id))
|
|
106
|
+
if node_b:
|
|
107
|
+
self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b)
|
|
108
|
+
# 匹配上的两个模块中的所有api, 忽略dump调用次数,按照名称一致+模块中的调用顺序进行匹配
|
|
109
|
+
recount_result_n = self._recount_api_node(node_n)
|
|
110
|
+
recount_result_b = self._recount_api_node(node_b)
|
|
111
|
+
for recount_node_id, node_id_n in recount_result_n.items():
|
|
112
|
+
api_node_n = self.graph_n.node_map.get(node_id_n)
|
|
113
|
+
if not api_node_n:
|
|
114
|
+
continue
|
|
115
|
+
api_node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(
|
|
116
|
+
api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id)))
|
|
117
|
+
if api_node_b:
|
|
118
|
+
self._process_matched_nodes(api_node_n, api_node_b, ancestors_n, ancestors_b)
|
|
119
|
+
node_list.extend(node_n.subnodes)
|
|
120
|
+
|
|
121
|
+
node_list = [node_root]
|
|
122
|
+
while node_list:
|
|
123
|
+
compare_single_nodes_fuzzy(node_list.pop(0))
|
|
124
|
+
|
|
70
125
|
def _parse_param(self, dump_path_param, output_path):
|
|
71
126
|
self.dump_path_param = dump_path_param
|
|
72
127
|
self.output_path = output_path
|
|
@@ -81,7 +136,7 @@ class GraphComparator:
|
|
|
81
136
|
if not self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE:
|
|
82
137
|
return
|
|
83
138
|
df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode)
|
|
84
|
-
df = run_real_data(self.dump_path_param, df, self.framework,
|
|
139
|
+
df = run_real_data(self.dump_path_param, df, self.framework, self.is_cross_framework)
|
|
85
140
|
compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()}
|
|
86
141
|
for node in self.ma.compare_nodes:
|
|
87
142
|
precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
|
|
@@ -103,49 +158,6 @@ class GraphComparator:
|
|
|
103
158
|
else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY))
|
|
104
159
|
node.data[GraphConst.JSON_INDEX_KEY] = precision_index
|
|
105
160
|
|
|
106
|
-
def _compare_nodes(self, node_n):
|
|
107
|
-
"""
|
|
108
|
-
递归遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比
|
|
109
|
-
这里采用先序遍历,好处在于当这个节点被比较时,他的先序已经被匹配,这可以为后续的模糊匹配提供重要信息
|
|
110
|
-
"""
|
|
111
|
-
if self.mapping_dict:
|
|
112
|
-
node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, self.mapping_dict)
|
|
113
|
-
if node_b:
|
|
114
|
-
ancestors_n.append(node_n.id)
|
|
115
|
-
ancestors_b.append(node_b.id)
|
|
116
|
-
node_n.matched_node_link = ancestors_b
|
|
117
|
-
node_b.matched_node_link = ancestors_n
|
|
118
|
-
else:
|
|
119
|
-
node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b)
|
|
120
|
-
if node_b:
|
|
121
|
-
ancestors.append(node_b.id)
|
|
122
|
-
node_n.add_link(node_b, ancestors)
|
|
123
|
-
if node_b:
|
|
124
|
-
# 真实数据比对只会得到基本信息,并没有精度指标,需要调用多进程对比接口
|
|
125
|
-
self._get_and_add_result(node_n, node_b)
|
|
126
|
-
for subnode in node_n.subnodes:
|
|
127
|
-
self._compare_nodes(subnode)
|
|
128
|
-
|
|
129
|
-
def _compare_nodes_fuzzy(self, node_n):
|
|
130
|
-
if node_n.op != NodeOp.function_api:
|
|
131
|
-
# 模块经过模糊匹配
|
|
132
|
-
node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id))
|
|
133
|
-
if node_b:
|
|
134
|
-
self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b)
|
|
135
|
-
# 匹配上的两个模块中的所有api, 忽略dump调用次数,按照名称一致+模块中的调用顺序进行匹配
|
|
136
|
-
recount_result_n = self._recount_api_node(node_n)
|
|
137
|
-
recount_result_b = self._recount_api_node(node_b)
|
|
138
|
-
for recount_node_id, node_id_n in recount_result_n.items():
|
|
139
|
-
api_node_n = self.graph_n.node_map.get(node_id_n)
|
|
140
|
-
if not api_node_n:
|
|
141
|
-
continue
|
|
142
|
-
api_node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(
|
|
143
|
-
api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id)))
|
|
144
|
-
if api_node_b:
|
|
145
|
-
self._process_matched_nodes(api_node_n, api_node_b, ancestors_n, ancestors_b)
|
|
146
|
-
for sub_node in node_n.subnodes:
|
|
147
|
-
self._compare_nodes_fuzzy(sub_node)
|
|
148
|
-
|
|
149
161
|
def _get_and_add_result(self, node_n, node_b):
|
|
150
162
|
compare_result_list = compare_node([node_n.id, node_b.id],
|
|
151
163
|
[self.data_n_dict, self.data_b_dict],
|
|
@@ -13,8 +13,8 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import json
|
|
17
16
|
import math
|
|
17
|
+
import json
|
|
18
18
|
from msprobe.core.common.const import CompareConst, Const
|
|
19
19
|
from msprobe.visualization.utils import ToolTip, GraphConst, str2float
|
|
20
20
|
|
|
@@ -25,6 +25,12 @@ class ModeAdapter:
|
|
|
25
25
|
self.csv_data = []
|
|
26
26
|
self.compare_nodes = []
|
|
27
27
|
|
|
28
|
+
@staticmethod
|
|
29
|
+
def _is_invalid(value):
|
|
30
|
+
if not isinstance(value, float):
|
|
31
|
+
return False
|
|
32
|
+
return math.isnan(value) or math.isinf(value)
|
|
33
|
+
|
|
28
34
|
@staticmethod
|
|
29
35
|
def _add_md5_compare_data(node_data, compare_data_dict):
|
|
30
36
|
precision_index = GraphConst.MAX_INDEX_KEY
|
|
@@ -49,6 +55,8 @@ class ModeAdapter:
|
|
|
49
55
|
for key, value in node_data.items():
|
|
50
56
|
if not isinstance(value, dict):
|
|
51
57
|
continue
|
|
58
|
+
if value.get(Const.MAX) is None:
|
|
59
|
+
continue
|
|
52
60
|
compare_data = compare_data_dict.get(key)
|
|
53
61
|
if compare_data:
|
|
54
62
|
headers = CompareConst.COMPARE_RESULT_HEADER
|
|
@@ -67,9 +75,13 @@ class ModeAdapter:
|
|
|
67
75
|
if thousandth is not None:
|
|
68
76
|
numbers.append(thousandth)
|
|
69
77
|
node_data[key] = value
|
|
78
|
+
if ModeAdapter._is_invalid(value.get(Const.MAX)) or ModeAdapter._is_invalid(value.get(Const.MIN)):
|
|
79
|
+
numbers.append(CompareConst.N_A)
|
|
70
80
|
# 双千指标都是None的异常情况
|
|
71
81
|
if not numbers:
|
|
72
82
|
min_thousandth = None
|
|
83
|
+
elif CompareConst.N_A in numbers:
|
|
84
|
+
min_thousandth = CompareConst.N_A
|
|
73
85
|
else:
|
|
74
86
|
min_thousandth = min(numbers + [min_thousandth])
|
|
75
87
|
return min_thousandth
|
|
@@ -81,6 +93,8 @@ class ModeAdapter:
|
|
|
81
93
|
for key, data_info in node_data.items():
|
|
82
94
|
if not isinstance(data_info, dict):
|
|
83
95
|
continue
|
|
96
|
+
if data_info.get(Const.MAX) is None:
|
|
97
|
+
continue
|
|
84
98
|
compare_data = compare_data_dict.get(key)
|
|
85
99
|
if compare_data:
|
|
86
100
|
# 对应比对结果csv的列
|
|
@@ -92,6 +106,8 @@ class ModeAdapter:
|
|
|
92
106
|
relative_err = str2float(data_info.get(item))
|
|
93
107
|
max_relative_err = max(max_relative_err, relative_err)
|
|
94
108
|
node_data[key] = data_info
|
|
109
|
+
if ModeAdapter._is_invalid(data_info.get(Const.MAX)) or ModeAdapter._is_invalid(data_info.get(Const.MIN)):
|
|
110
|
+
max_relative_err = GraphConst.MAX_INDEX_KEY
|
|
95
111
|
max_relative_err = 1 if max_relative_err > 1 else max_relative_err
|
|
96
112
|
return max_relative_err
|
|
97
113
|
|
|
@@ -133,7 +149,11 @@ class ModeAdapter:
|
|
|
133
149
|
ModeAdapter._check_list_len(compare_data_dict_list, 1)
|
|
134
150
|
min_thousandth_in = ModeAdapter._add_real_compare_data(node.input_data, compare_data_dict_list[0])
|
|
135
151
|
min_thousandth_out = ModeAdapter._add_real_compare_data(node.output_data, compare_data_dict_list[0])
|
|
136
|
-
if
|
|
152
|
+
if CompareConst.N_A == min_thousandth_out:
|
|
153
|
+
change_percentage = GraphConst.MAX_INDEX_KEY
|
|
154
|
+
elif CompareConst.N_A == min_thousandth_in:
|
|
155
|
+
change_percentage = GraphConst.MIN_INDEX_KEY
|
|
156
|
+
elif min_thousandth_in is not None and min_thousandth_out is not None:
|
|
137
157
|
change_percentage = min_thousandth_in - min_thousandth_out
|
|
138
158
|
else:
|
|
139
159
|
change_percentage = GraphConst.MIN_INDEX_KEY
|
|
@@ -157,24 +177,6 @@ class ModeAdapter:
|
|
|
157
177
|
return
|
|
158
178
|
self.csv_data.extend(compare_result_list)
|
|
159
179
|
|
|
160
|
-
def add_error_key(self, node_data):
|
|
161
|
-
"""
|
|
162
|
-
根据不同的模式进行提供不同错误信息
|
|
163
|
-
"""
|
|
164
|
-
for key, value in node_data.items():
|
|
165
|
-
if not isinstance(value, dict):
|
|
166
|
-
continue
|
|
167
|
-
if self.compare_mode == GraphConst.SUMMARY_COMPARE:
|
|
168
|
-
message = [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
|
|
169
|
-
CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
|
|
170
|
-
elif self.compare_mode == GraphConst.REAL_DATA_COMPARE:
|
|
171
|
-
message = [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]
|
|
172
|
-
else:
|
|
173
|
-
# 输出件优化
|
|
174
|
-
message = []
|
|
175
|
-
value[GraphConst.ERROR_KEY] = message
|
|
176
|
-
node_data[key] = value
|
|
177
|
-
|
|
178
180
|
def get_tool_tip(self):
|
|
179
181
|
"""
|
|
180
182
|
用于前端展示字段的具体含义
|
|
@@ -12,10 +12,11 @@
|
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
|
+
|
|
15
16
|
from msprobe.core.overflow_check.level import OverflowLevel
|
|
16
|
-
from msprobe.visualization.graph.node_op import NodeOp
|
|
17
17
|
from msprobe.visualization.utils import GraphConst
|
|
18
18
|
from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data, compare_data_fuzzy
|
|
19
|
+
from msprobe.core.common.log import logger
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
class BaseNode:
|
|
@@ -86,15 +87,15 @@ class BaseNode:
|
|
|
86
87
|
self.matched_node_link = ancestors
|
|
87
88
|
node.matched_node_link = ancestors
|
|
88
89
|
|
|
89
|
-
def to_dict(self):
|
|
90
|
+
def to_dict(self, compare_mode=None):
|
|
90
91
|
"""
|
|
91
92
|
输出数据
|
|
92
93
|
"""
|
|
93
94
|
result = {
|
|
94
95
|
'id': self.id,
|
|
95
96
|
'node_type': self.op.value,
|
|
96
|
-
'output_data': format_node_data(self.output_data, self.id),
|
|
97
|
-
'input_data': format_node_data(self.input_data, self.id),
|
|
97
|
+
'output_data': format_node_data(self.output_data, self.id, compare_mode),
|
|
98
|
+
'input_data': format_node_data(self.input_data, self.id, compare_mode),
|
|
98
99
|
'upnode': self.upnode.id if self.upnode else 'None',
|
|
99
100
|
'subnodes': [node.id for node in self.subnodes],
|
|
100
101
|
'matched_node_link': self.matched_node_link,
|
|
@@ -114,7 +115,13 @@ class BaseNode:
|
|
|
114
115
|
"""
|
|
115
116
|
ancestors = []
|
|
116
117
|
current_node = self.upnode
|
|
118
|
+
seen_nodes = set()
|
|
117
119
|
while current_node:
|
|
120
|
+
if current_node.id in seen_nodes:
|
|
121
|
+
logger.warning(f'Detected a cycle in the node structure and cannot get node ancestors, '
|
|
122
|
+
f'current node is {current_node.id}.')
|
|
123
|
+
return []
|
|
124
|
+
seen_nodes.add(current_node.id)
|
|
118
125
|
ancestors.append(current_node.id)
|
|
119
126
|
current_node = current_node.upnode
|
|
120
127
|
return list(reversed(ancestors))
|
|
@@ -107,15 +107,6 @@ class DistributedAnalyzer:
|
|
|
107
107
|
return None, None
|
|
108
108
|
return group_ranks, group_id
|
|
109
109
|
|
|
110
|
-
@staticmethod
|
|
111
|
-
def _get_batch_group_info(node, rank):
|
|
112
|
-
for data in node.input_data.values():
|
|
113
|
-
group_id = data.get('group_id')
|
|
114
|
-
if group_id is not None:
|
|
115
|
-
return group_id
|
|
116
|
-
logger.warning(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
|
|
117
|
-
return None
|
|
118
|
-
|
|
119
110
|
def distributed_match(self):
|
|
120
111
|
for rank, graph in self.graphs.items():
|
|
121
112
|
nodes = graph.node_map
|
|
@@ -377,7 +368,7 @@ class DistributedAnalyzer:
|
|
|
377
368
|
target_api_name = self.config.get(api_name)[0]
|
|
378
369
|
target_rank = int(id_info[1].replace(Const.RANK, ''))
|
|
379
370
|
except Exception as e:
|
|
380
|
-
logger.warning(f'Failed to
|
|
371
|
+
logger.warning(f'Failed to parse batch p2p parameter with error info: {e}.')
|
|
381
372
|
continue
|
|
382
373
|
target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank, target_api_name)
|
|
383
374
|
if not target_node:
|
|
@@ -20,9 +20,6 @@ from msprobe.core.common.log import logger
|
|
|
20
20
|
from msprobe.core.common.const import Const
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
MAX_RECUR_LEVEL = 100
|
|
24
|
-
|
|
25
|
-
|
|
26
23
|
class Graph:
|
|
27
24
|
def __init__(self, model_name, data_path='', dump_data=None):
|
|
28
25
|
self.node_map = {}
|
|
@@ -67,7 +64,6 @@ class Graph:
|
|
|
67
64
|
ancestors_b = node_b.get_ancestors()
|
|
68
65
|
return node_b, ancestors_n, ancestors_b
|
|
69
66
|
|
|
70
|
-
|
|
71
67
|
@staticmethod
|
|
72
68
|
def fuzzy_match(node_n, node_b):
|
|
73
69
|
if not node_n or not node_b or not node_n.fuzzy_eq(node_b):
|
|
@@ -76,13 +72,6 @@ class Graph:
|
|
|
76
72
|
ancestors_b = node_b.get_ancestors()
|
|
77
73
|
return node_b, ancestors_n, ancestors_b
|
|
78
74
|
|
|
79
|
-
@staticmethod
|
|
80
|
-
def dfs(node, result):
|
|
81
|
-
info = node.to_dict()
|
|
82
|
-
result[node.id] = info
|
|
83
|
-
for subnode in node.subnodes:
|
|
84
|
-
Graph.dfs(subnode, result)
|
|
85
|
-
|
|
86
75
|
@staticmethod
|
|
87
76
|
def split_nodes_by_micro_step(nodes):
|
|
88
77
|
"""
|
|
@@ -157,7 +146,7 @@ class Graph:
|
|
|
157
146
|
"""
|
|
158
147
|
return self.node_map.get(node_id, None)
|
|
159
148
|
|
|
160
|
-
def to_dict(self):
|
|
149
|
+
def to_dict(self, compare_mode=None):
|
|
161
150
|
"""
|
|
162
151
|
用于数据输出
|
|
163
152
|
"""
|
|
@@ -166,7 +155,7 @@ class Graph:
|
|
|
166
155
|
result[GraphConst.JSON_DATA_KEY] = self.data_path
|
|
167
156
|
result[GraphConst.JSON_NODE_KEY] = {}
|
|
168
157
|
for node_id in self.node_map:
|
|
169
|
-
info = self.node_map.get(node_id).to_dict()
|
|
158
|
+
info = self.node_map.get(node_id).to_dict(compare_mode)
|
|
170
159
|
result[GraphConst.JSON_NODE_KEY][node_id] = info
|
|
171
160
|
return result
|
|
172
161
|
|
|
@@ -24,7 +24,6 @@ class NodeOp(Enum):
|
|
|
24
24
|
function_api = 1
|
|
25
25
|
api_collection = 9
|
|
26
26
|
|
|
27
|
-
|
|
28
27
|
@staticmethod
|
|
29
28
|
def get_node_op(node_name: str):
|
|
30
29
|
"""
|
|
@@ -37,5 +36,5 @@ class NodeOp(Enum):
|
|
|
37
36
|
pattern = op_patterns[index]
|
|
38
37
|
if re.match(pattern, node_name):
|
|
39
38
|
return op
|
|
40
|
-
logger.warning(f"Cannot
|
|
39
|
+
logger.warning(f"Cannot parse node_name {node_name} into NodeOp, default parsing as module.")
|
|
41
40
|
return NodeOp.module
|