mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# 标准库导入
|
|
17
|
+
import multiprocessing
|
|
18
|
+
from multiprocessing import Manager
|
|
19
|
+
import os
|
|
20
|
+
import signal
|
|
21
|
+
import sys
|
|
22
|
+
import time
|
|
23
|
+
|
|
24
|
+
# 第三方库导入
|
|
25
|
+
from mindspore import context
|
|
26
|
+
import numpy as np
|
|
27
|
+
from tqdm import tqdm
|
|
28
|
+
|
|
29
|
+
# 本地应用/库特定导入
|
|
30
|
+
from msprobe.core.common.const import Const, CompareConst, MsCompareConst
|
|
31
|
+
from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker, BasicInfoAndStatus
|
|
32
|
+
from msprobe.mindspore.api_accuracy_checker.multi_data_manager import MultiDataManager
|
|
33
|
+
from msprobe.mindspore.common.log import logger
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class MultiApiAccuracyChecker(ApiAccuracyChecker):
    """Run the MindSpore API accuracy check in parallel on several devices.

    The collected API infos are split round-robin over the device IDs in
    ``args.device_id``. Each device is handled by its own
    ``multiprocessing.Process``, which reports one progress unit per API
    through a shared queue.
    """

    def __init__(self, args):
        # NOTE(review): ApiAccuracyChecker.__init__ is not called here; the
        # attributes below are the ones this subclass sets up itself.
        self.api_infos = {}

        # Manager-backed flag shared by all worker processes; MultiDataManager
        # uses it to decide whether the CSV headers still need to be written.
        self.manager = Manager()
        self.is_first_write = self.manager.Value('b', True)

        self.multi_data_manager = MultiDataManager(args.out_path, args.result_csv_path, self.is_first_write)

        self.args = args

        # Device ID of the current worker process; only used to prefix log messages.
        self.current_device_id = None

    def process_on_device(self, device_id, api_infos, progress_queue):
        """
        Process a subset of APIs on a specific device.

        Args:
            device_id (int): ID of the device to use.
            api_infos (list): ``(api_name_str, api_info)`` tuples assigned to this device.
            progress_queue (multiprocessing.Queue): queue used to report progress; one ``1``
                is put for every API handled, whether it was checked or skipped.
        """
        self.current_device_id = device_id
        context.set_context(device_id=device_id)

        for api_name_str, api_info in api_infos:
            logger.debug(f"Processing API: {api_name_str}, Device: {device_id}")

            if not self.multi_data_manager.is_unique_api(api_name_str):
                logger.debug(f"API {api_name_str} is not unique, skipping.")
                progress_queue.put(1)
                continue

            forward_output_list = self.process_forward(api_name_str, api_info)
            if forward_output_list is not Const.EXCEPTION_NONE:
                self.multi_data_manager.record(forward_output_list)

            backward_output_list = self.process_backward(api_name_str, api_info)
            if backward_output_list is not Const.EXCEPTION_NONE:
                self.multi_data_manager.record(backward_output_list)

            self.multi_data_manager.save_results(api_name_str)
            progress_queue.put(1)

    def run_and_compare(self):
        """
        Split ``self.api_infos`` over ``self.args.device_id`` and check them in parallel.

        One worker process is started per device. The main process drives a tqdm
        progress bar from the shared progress queue. It stops waiting when every task
        has been reported, when all workers have exited, or when no progress arrives
        within ``Const.PROGRESS_TIMEOUT`` seconds. Workers still alive afterwards are
        terminated.
        """
        device_ids = self.args.device_id
        if not device_ids:
            # Guard: round-robin splitting over zero devices would divide by zero.
            logger.error("No device ID provided, skipping multi-device API accuracy check.")
            return

        partitions = self._split_api_infos(list(self.api_infos.items()), len(device_ids))
        total_tasks = len(self.api_infos)

        progress_queue = multiprocessing.Queue()
        with tqdm(total=total_tasks, desc="Total Progress", ncols=100) as pbar:
            processes = []
            for device_id, partition in zip(device_ids, partitions):
                process = multiprocessing.Process(target=self.process_on_device,
                                                  args=(device_id, partition, progress_queue))
                processes.append(process)
                process.start()

            self._wait_for_progress(processes, progress_queue, pbar, total_tasks)
            self._join_workers(processes)

    @staticmethod
    def _split_api_infos(api_info_items, partition_count):
        """Distribute items round-robin: item ``i`` goes to partition ``i % partition_count``."""
        return [api_info_items[index::partition_count] for index in range(partition_count)]

    @staticmethod
    def _wait_for_progress(processes, progress_queue, pbar, total_tasks):
        """
        Consume progress units until all tasks are reported, all workers exit, or progress stalls.

        Worker liveness is sampled *before* each queue read. A worker that had already
        exited before the read has flushed its queued items, so an empty read after that
        means nothing more will arrive from it. This replaces the previous approach of
        shrinking ``total_tasks`` per crashed worker. That approach removed items from
        the list being iterated and used indices that no longer matched the partitions.
        It also double-counted tasks the crashed worker had already finished.
        """
        poll_interval = min(1.0, Const.PROGRESS_TIMEOUT)
        completed_tasks = 0
        last_progress_time = time.monotonic()
        while completed_tasks < total_tasks:
            any_worker_alive = any(process.is_alive() for process in processes)
            try:
                completed_tasks += progress_queue.get(timeout=poll_interval)
            except multiprocessing.queues.Empty:
                if not any_worker_alive:
                    logger.error("All worker processes exited before reporting every task. "
                                 "Skipping remaining tasks.")
                    break
                if time.monotonic() - last_progress_time >= Const.PROGRESS_TIMEOUT:
                    logger.error("Timeout while waiting for progress updates. Skipping remaining tasks.")
                    break
                continue
            pbar.update(1)
            last_progress_time = time.monotonic()

    @staticmethod
    def _join_workers(processes):
        """Wait for every worker, force-terminate stragglers and log abnormal exit codes."""
        for process in processes:
            process.join(timeout=Const.PROGRESS_TIMEOUT)
            if process.is_alive():
                logger.error(f"Process {process.pid} did not terminate. Forcing termination.")
                process.terminate()
                # Reap the terminated process so it does not linger as a zombie.
                process.join(timeout=Const.PROGRESS_TIMEOUT)
            elif process.exitcode != 0:
                logger.error(f"Process {process.pid} exited with code {process.exitcode}.")

    def process_forward(self, api_name_str, api_info):
        """
        Overrides the parent class's process_forward method to log the device ID when exceptions occur.

        Parameters:
            api_name_str (str): The name of the API.
            api_info (object): The API information object.

        Returns:
            The forward output list. Returns ``Const.EXCEPTION_NONE`` when forward info is
            missing, inputs cannot be prepared, or running/comparing raised.
        """
        return self._process_direction(api_name_str, api_info, Const.FORWARD, "forward",
                                       api_info.check_forward_info)

    def process_backward(self, api_name_str, api_info):
        """
        Overrides the parent class's process_backward method to log the device ID when exceptions occur.

        Parameters:
            api_name_str (str): The name of the API.
            api_info (object): The API information object.

        Returns:
            The backward output list. Returns ``Const.EXCEPTION_NONE`` when backward info is
            missing, inputs cannot be prepared, or running/comparing raised.
        """
        return self._process_direction(api_name_str, api_info, Const.BACKWARD, "backward",
                                       api_info.check_backward_info)

    def _process_direction(self, api_name_str, api_info, direction, label, has_direction_info):
        """Shared forward/backward flow: validate info, prepare inputs, run and compare.

        ``direction`` is passed to the parent's helpers (``Const.FORWARD``/``Const.BACKWARD``).
        ``label`` is the word used in log messages, and ``has_direction_info`` is the
        matching ``api_info.check_*_info`` method.
        """
        if not has_direction_info():
            logger.debug(
                f"[Device {self.current_device_id}] API: {api_name_str} lacks {label} information, skipping "
                f"{label} check.")
            return Const.EXCEPTION_NONE

        try:
            inputs_aggregation = self.prepare_api_input_aggregation(api_info, direction)
        except Exception as e:
            logger.warning(
                f"[Device {self.current_device_id}] Exception occurred while getting {label} API inputs for "
                f"{api_name_str}. Skipping {label} check. Detailed exception information: {e}.")
            return Const.EXCEPTION_NONE

        # Failure value is EXCEPTION_NONE on every error path. process_on_device tests
        # for it with ``is not``, so returning it keeps failed runs from being recorded.
        output_list = Const.EXCEPTION_NONE
        try:
            output_list = self.run_and_compare_helper(api_info, api_name_str, inputs_aggregation, direction)
        except Exception as e:
            logger.warning(
                f"[Device {self.current_device_id}] Exception occurred while running and comparing {api_name_str} "
                f"{label} API. Detailed exception information: {e}.")
        return output_list
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
import multiprocessing
|
|
18
|
+
import os
|
|
19
|
+
|
|
20
|
+
from msprobe.mindspore.api_accuracy_checker.data_manager import (DataManager, ResultCsvEntry, write_csv_header,
|
|
21
|
+
get_result_csv_header, get_detail_csv_header,
|
|
22
|
+
check_csv_header)
|
|
23
|
+
from msprobe.mindspore.common.log import logger
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class MultiDataManager(DataManager):
    """DataManager variant used by the worker processes of the multi-device accuracy checker.

    CSV headers are written once across all processes, coordinated through a shared
    boolean flag. Every save is serialized with a ``multiprocessing.Lock``.
    """

    def __init__(self, csv_dir, result_csv_path, shared_is_first_write):
        super().__init__(csv_dir, result_csv_path)

        # Shared flag (the caller passes a ``Manager().Value('b', True)``) that tells
        # whether any process still has to write the CSV headers.
        self.shared_is_first_write = shared_is_first_write
        # NOTE: this lock only synchronizes processes that received this same instance
        # (e.g. through fork or as part of a Process target); a lock created inside a
        # worker would not be shared.
        self.lock = multiprocessing.Lock()

    def save_results(self, api_name_str):
        """Append the recorded results for ``api_name_str`` to the CSV files, then clear them.

        The whole operation runs under ``self.lock``, so header creation and row appends
        from different processes cannot interleave.
        """
        with self.lock:
            # The check-and-set of the shared flag happens under the lock, so only the
            # first process to reach this point writes the headers.
            if self.is_first_write and self.shared_is_first_write.value:
                self.shared_is_first_write.value = False
                self.is_first_write = False  # avoid writing the header again from this process
                logger.info("Writing CSV headers for the first time.")
                write_csv_header(self.detail_out_path, get_detail_csv_header)
                write_csv_header(self.result_out_path, get_result_csv_header)

            # Write the detailed output and the result summary, then reset the buffer.
            self.to_detail_csv(self.detail_out_path)
            logger.debug(f"Detailed output for {api_name_str} written to {self.detail_out_path}.")

            self.to_result_csv(self.result_out_path)
            logger.debug(f"Result summary for {api_name_str} written to {self.result_out_path}.")

            # Clear the recorded results so the next call starts from an empty buffer.
            self.clear_results()

    def clear_results(self):
        """Empty ``self.results`` in place.

        This method takes no lock itself; ``save_results`` calls it while holding ``self.lock``.
        """
        logger.debug("Clearing results data.")
        self.results.clear()
|
|
@@ -1,7 +1,23 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
3
16
|
import mindspore
|
|
17
|
+
import numpy as np
|
|
4
18
|
import torch
|
|
19
|
+
from mindspore._c_expression import typing
|
|
20
|
+
from mindspore.common import dtype as mstype
|
|
5
21
|
|
|
6
22
|
INT8 = "Int8"
|
|
7
23
|
UINT8 = "UInt8"
|
|
@@ -18,7 +34,6 @@ BOOL = "Bool"
|
|
|
18
34
|
BFLOAT16 = "BFloat16"
|
|
19
35
|
INT4 = "Int4"
|
|
20
36
|
|
|
21
|
-
|
|
22
37
|
dtype_str_to_ms_dtype = {
|
|
23
38
|
INT8: mstype.int8,
|
|
24
39
|
UINT8: mstype.uint8,
|
|
@@ -37,7 +52,6 @@ dtype_str_to_ms_dtype = {
|
|
|
37
52
|
}
|
|
38
53
|
# Inverse of dtype_str_to_ms_dtype: MindSpore dtype -> dtype name string
# (if two names mapped to the same dtype, the later name wins).
ms_dtype_to_dtype_str = dict(zip(dtype_str_to_ms_dtype.values(), dtype_str_to_ms_dtype.keys()))
|
|
39
54
|
|
|
40
|
-
|
|
41
55
|
dtype_str_to_np_dtype = {
|
|
42
56
|
INT8: np.int8,
|
|
43
57
|
UINT8: np.uint8,
|
|
@@ -75,6 +89,8 @@ FLOAT_TYPE_STR = "float"
|
|
|
75
89
|
SLICE_TYPE_STR = "slice"
|
|
76
90
|
TUPLE_TYPE_STR = "tuple"
|
|
77
91
|
STR_TYPE_STR = "str"
|
|
92
|
+
MINDSPORE_DTYPE_TYPE_STR = "mindspore.dtype"
|
|
93
|
+
TORCH_DTYPE_TYPE_STR = "torch.dtype"
|
|
78
94
|
|
|
79
95
|
api_info_type_str_to_type = {
|
|
80
96
|
MINDSPORE_TENSOR_TYPE_STR: mindspore.Tensor,
|
|
@@ -83,6 +99,7 @@ api_info_type_str_to_type = {
|
|
|
83
99
|
FLOAT_TYPE_STR: float,
|
|
84
100
|
SLICE_TYPE_STR: slice,
|
|
85
101
|
STR_TYPE_STR: str,
|
|
102
|
+
MINDSPORE_DTYPE_TYPE_STR: typing.Type,
|
|
86
103
|
}
|
|
87
104
|
# Inverse of api_info_type_str_to_type: Python/MindSpore type -> type name string
# (if two names mapped to the same type, the later name wins).
type_to_api_info_type_str = dict(zip(api_info_type_str_to_type.values(), api_info_type_str_to_type.keys()))
|
|
88
105
|
|
|
@@ -111,4 +128,4 @@ uint_dtype_str_list = [
|
|
|
111
128
|
UINT16,
|
|
112
129
|
UINT32,
|
|
113
130
|
UINT64,
|
|
114
|
-
]
|
|
131
|
+
]
|
|
@@ -1,8 +1,24 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
2
16
|
from msprobe.core.common.const import Const
|
|
17
|
+
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
3
18
|
from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list
|
|
4
19
|
from msprobe.mindspore.common.log import logger
|
|
5
20
|
|
|
21
|
+
|
|
6
22
|
def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_type=None, accepted_value=None):
|
|
7
23
|
'''
|
|
8
24
|
Args:
|
|
@@ -22,30 +38,30 @@ def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_t
|
|
|
22
38
|
3. value is not accepted type
|
|
23
39
|
4. value is not accepted value
|
|
24
40
|
'''
|
|
25
|
-
parse_failed_exception = ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed)
|
|
26
41
|
if not isinstance(dict_instance, dict):
|
|
27
|
-
|
|
42
|
+
error_info = "check_and_get_from_json_dict failed: input is not a dict"
|
|
43
|
+
raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
|
|
28
44
|
value = dict_instance.get(key)
|
|
29
45
|
if value is None:
|
|
30
|
-
|
|
31
|
-
|
|
46
|
+
error_info = f"check_and_get_from_json_dict failed: {key_description} is missing"
|
|
47
|
+
raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
|
|
32
48
|
elif accepted_type is not None and not isinstance(value, accepted_type):
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
parse_failed_exception)
|
|
49
|
+
error_info = f"check_and_get_from_json_dict failed: {key_description} is not accepted type: {accepted_type}"
|
|
50
|
+
raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
|
|
36
51
|
elif accepted_value is not None and value not in accepted_value:
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
parse_failed_exception)
|
|
52
|
+
error_info = f"check_and_get_from_json_dict failed: {key_description} is not accepted value: {accepted_value}"
|
|
53
|
+
raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
|
|
40
54
|
return value
|
|
41
55
|
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
56
|
+
|
|
57
|
+
def convert_to_tuple(args):
|
|
58
|
+
if isinstance(args, (tuple, list)):
|
|
59
|
+
return tuple(args)
|
|
45
60
|
else:
|
|
46
|
-
input_list = [
|
|
61
|
+
input_list = [args]
|
|
47
62
|
return tuple(input_list)
|
|
48
63
|
|
|
64
|
+
|
|
49
65
|
def trim_output_compute_element_list(compute_element_list, forward_or_backward):
|
|
50
66
|
'''
|
|
51
67
|
Args:
|
|
@@ -55,12 +71,13 @@ def trim_output_compute_element_list(compute_element_list, forward_or_backward):
|
|
|
55
71
|
trimmed_list = []
|
|
56
72
|
for compute_element in compute_element_list:
|
|
57
73
|
if compute_element.get_parameter() is None or \
|
|
58
|
-
|
|
74
|
+
(forward_or_backward == Const.BACKWARD and compute_element.get_dtype() not in float_dtype_str_list):
|
|
59
75
|
# trim case: 1. parameter is None. 2. backward output has non float parameter
|
|
60
76
|
continue
|
|
61
77
|
trimmed_list.append(compute_element)
|
|
62
78
|
return trimmed_list
|
|
63
79
|
|
|
80
|
+
|
|
64
81
|
class GlobalContext:
|
|
65
82
|
def __init__(self):
|
|
66
83
|
self.is_constructed = True
|
|
@@ -77,4 +94,4 @@ class GlobalContext:
|
|
|
77
94
|
return self.is_constructed
|
|
78
95
|
|
|
79
96
|
|
|
80
|
-
global_context = GlobalContext()
|
|
97
|
+
global_context = GlobalContext()
|
|
@@ -1,4 +1,19 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope
|
|
2
17
|
from msprobe.core.common.const import Const
|
|
3
18
|
|
|
4
19
|
|
|
@@ -9,10 +24,7 @@ class CellProcessor:
|
|
|
9
24
|
module_node = {}
|
|
10
25
|
|
|
11
26
|
def __init__(self, scope):
|
|
12
|
-
if isinstance(scope, ModuleRangeScope)
|
|
13
|
-
self.scope = scope
|
|
14
|
-
else:
|
|
15
|
-
self.scope = None
|
|
27
|
+
self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
|
|
16
28
|
|
|
17
29
|
@staticmethod
|
|
18
30
|
def set_cell_count(cell_name):
|
|
@@ -21,30 +33,29 @@ class CellProcessor:
|
|
|
21
33
|
else:
|
|
22
34
|
CellProcessor.cell_count[cell_name] += 1
|
|
23
35
|
return CellProcessor.cell_count[cell_name]
|
|
24
|
-
|
|
36
|
+
|
|
25
37
|
@classmethod
|
|
26
38
|
def reset_cell_stats(cls):
|
|
27
39
|
cls.cell_count = {}
|
|
28
40
|
cls.cell_stack = []
|
|
29
41
|
cls.api_parent_node = ""
|
|
30
42
|
cls.module_node = {}
|
|
31
|
-
|
|
43
|
+
|
|
32
44
|
def node_hook(self, name_prefix, start_or_stop, **kwargs):
|
|
33
|
-
def begin_hook(cell,
|
|
34
|
-
|
|
35
|
-
cell.mindstudio_reserved_name = full_name = name_prefix + Const.SEP + str(index)
|
|
45
|
+
def begin_hook(cell, input_data):
|
|
46
|
+
full_name = self.set_and_get_reserved_name(cell, name_prefix, is_called_by_pre_hook=True)
|
|
36
47
|
if CellProcessor.cell_stack:
|
|
37
48
|
CellProcessor.module_node[full_name] = CellProcessor.cell_stack[-1]
|
|
38
49
|
else:
|
|
39
50
|
CellProcessor.module_node[full_name] = None
|
|
40
|
-
|
|
51
|
+
|
|
41
52
|
CellProcessor.cell_stack.append(full_name)
|
|
42
53
|
CellProcessor.api_parent_node = full_name
|
|
43
54
|
|
|
44
55
|
if self.scope:
|
|
45
56
|
self.scope.begin_module(full_name)
|
|
46
57
|
|
|
47
|
-
def end_hook(cell,
|
|
58
|
+
def end_hook(cell, input_data, output_data):
|
|
48
59
|
if CellProcessor.cell_stack:
|
|
49
60
|
CellProcessor.cell_stack.pop()
|
|
50
61
|
if CellProcessor.cell_stack:
|
|
@@ -56,3 +67,13 @@ class CellProcessor:
|
|
|
56
67
|
self.scope.end_module(cell.mindstudio_reserved_name)
|
|
57
68
|
|
|
58
69
|
return begin_hook if Const.START == start_or_stop else end_hook
|
|
70
|
+
|
|
71
|
+
def set_and_get_reserved_name(self, cell, cell_name, is_called_by_pre_hook=False):
|
|
72
|
+
if not is_called_by_pre_hook and hasattr(cell, 'has_pre_hook_called') and cell.has_pre_hook_called:
|
|
73
|
+
cell.has_pre_hook_called = False
|
|
74
|
+
else:
|
|
75
|
+
if is_called_by_pre_hook:
|
|
76
|
+
cell.has_pre_hook_called = True
|
|
77
|
+
index = self.set_cell_count(cell_name)
|
|
78
|
+
cell.mindstudio_reserved_name = cell_name + Const.SEP + str(index)
|
|
79
|
+
return cell.mindstudio_reserved_name
|