mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
```diff
--- a/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py
+++ b/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py
@@ -33,6 +33,9 @@ from msprobe.mindspore.api_accuracy_checker.multi_data_manager import MultiDataM
 from msprobe.mindspore.common.log import logger
 from msprobe.mindspore.common.const import MsCompareConst
 
+from msprobe.core.data_dump.data_collector import build_data_collector
+from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
+
 
 class MultiApiAccuracyChecker(ApiAccuracyChecker):
     def __init__(self, args):
@@ -51,6 +54,12 @@ class MultiApiAccuracyChecker(ApiAccuracyChecker):
         # Initialize an attribute to store the current device ID (shown in log messages)
         self.current_device_id = None
 
+        self.save_error_data = args.save_error_data
+        if self.save_error_data:
+            config, dump_path_aggregation = self.init_save_error_data(args)
+            self.data_collector = build_data_collector(config)
+            self.data_collector.update_dump_paths(dump_path_aggregation)
+
     def process_on_device(self, device_id, api_infos, progress_queue):
         """
         Process a subset of APIs on the given device.
```
```diff
--- a/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py
+++ b/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py
@@ -108,7 +108,8 @@ def delete_torch_paths():
 
     if count_delete_env_path >= MsCompareConst.MAX_RECURSION_DEPTH - 1:
         raise Exception(f"Please check if you have a valid PyTorch and MindTorch environment, and ensure "
-                        f"the PYTHONPATH environment variable depth does not
+                        f"the PYTHONPATH environment variable depth does not "
+                        f"exceed {MsCompareConst.MAX_RECURSION_DEPTH}.")
 
 
     if not is_mindtorch():
```
```diff
--- a/msprobe/mindspore/cell_processor.py
+++ b/msprobe/mindspore/cell_processor.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,21 +13,50 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from
+from collections import OrderedDict
+
+from mindspore import Tensor
+from mindspore.common.hook_handle import HookHandle
+from mindspore.ops.operations import _inner_ops as inner
+
 from msprobe.core.common.const import Const
+from msprobe.core.common.exceptions import MsprobeException
+from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope, BaseScope
+from msprobe.mindspore.common.const import Const as MsConst
+from msprobe.mindspore.common.log import logger
+from msprobe.mindspore.common.utils import (
+    is_mindtorch,
+    get_cells_and_names_with_index,
+    has_kwargs_in_forward_hook,
+    is_graph_mode_cell_dump_allowed
+)
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump
+from msprobe.core.common.runtime import Runtime
+
+
+def get_cell_construct(construct):
+    def _construct(self, *args, **kwargs):
+        if hasattr(self, 'msprobe_hook'):
+            setattr(self, 'msprobe_input_kwargs', kwargs)
+        return construct(self, *args, **kwargs)
+    return _construct
 
 
 class CellProcessor:
     cell_count = {}
     cell_stack = []
-    api_parent_node =
+    api_parent_node = None
     module_node = {}
+    cell_bw_hook_kernels = {}
+    cell_backward_pre_hook = []
+    cell_backward_hook = []
 
     def __init__(self, scope):
         self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
 
     @staticmethod
-    def
+    def set_and_get_calls_number(cell_name):
         if cell_name not in CellProcessor.cell_count:
             CellProcessor.cell_count[cell_name] = 0
         else:
@@ -38,42 +67,184 @@ class CellProcessor:
     def reset_cell_stats(cls):
         cls.cell_count = {}
         cls.cell_stack = []
-        cls.api_parent_node =
+        cls.api_parent_node = None
         cls.module_node = {}
+        cls.cell_bw_hook_kernels = {}
+        cls.cell_backward_pre_hook = []
+        cls.cell_backward_hook = []
 
-    def
-
-
-
-
-
-
+    def register_cell_hook(self, models, build_hook, config: DebuggerConfig):
+        if not models:
+            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+                                   'The model cannot be None, when level is "L0" or "mix"')
+
+        is_registered = False
+        model_type = Const.MODULE if is_mindtorch() else Const.CELL
+        cells_with_index_in_pynative_mode, cells_with_index_in_graph_mode = get_cells_and_names_with_index(models)
+        construct_name = '_call_impl' if is_mindtorch() else '_run_construct'
+
+        for index, cells_and_names in cells_with_index_in_pynative_mode.items():
+            model = models if index == "-1" else models[int(index)]
+            for name, cell in cells_and_names:
+                if cell == model:
+                    continue
+
+                if not has_kwargs_in_forward_hook():
+                    if not hasattr(cell.__class__, 'msprobe_construct'):
+                        setattr(cell.__class__, 'msprobe_construct', True)
+                        if hasattr(cell.__class__, construct_name):
+                            setattr(cell.__class__, construct_name,
+                                    get_cell_construct(getattr(cell.__class__, construct_name)))
+                    setattr(cell, 'msprobe_hook', True)
+
+                cell_index = (index + Const.SEP) if index != "-1" else ""
+                prefix = f'{model_type}{Const.SEP}{cell_index}{name}{Const.SEP}{cell.__class__.__name__}{Const.SEP}'
+
+                forward_pre_hook = self.build_cell_hook(prefix, build_hook)
+                cell.register_forward_pre_hook(forward_pre_hook)
+
+            if not is_registered:
+                logger.info("The cell hook function is successfully mounted to the model.")
+                is_registered = True
+
+        if is_graph_mode_cell_dump_allowed(config):
+            cells_and_names_in_graph_mode = []
+            for index, cells_and_names in cells_with_index_in_graph_mode.items():
+                model = models if index == "-1" else models[int(index)]
+                for name, cell in cells_and_names:
+                    if cell == model:
+                        continue
+                    cell_index = (index + Const.SEP) if index != "-1" else ""
+                    cells_and_names_in_graph_mode.append((f'{cell_index}{name}', cell))
+
+            if cells_and_names_in_graph_mode:
+                Runtime.run_mode = MsConst.PYNATIVE_GRAPH_MODE
+                GraphModeCellDump(config, cells_and_names_in_graph_mode, strict=False).handle()
 
-
-
+    def build_cell_hook(self, cell_name, build_data_hook):
+        def forward_pre_hook(cell, args):
+            index = CellProcessor.set_and_get_calls_number(cell_name)
+            full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}'
+            full_backward_name = f'{cell_name}{Const.BACKWARD}{Const.SEP}{index}'
 
-
-            self.scope.begin_module(full_name)
+            self.set_construct_info_in_pre_hook(full_forward_name)
 
-
-
-
-
-
+            if not hasattr(cell, 'msprobe_forward_hook'):
+                if is_mindtorch():
+                    cell.register_forward_hook(forward_hook, prepend=True, with_kwargs=True)
+                else:
+                    forward_hook_dict = getattr(cell, '_forward_hook', OrderedDict())
+                    if has_kwargs_in_forward_hook():
+                        forward_hook_with_kwargs_dict = getattr(cell, '_forward_hook_with_kwargs', OrderedDict())
+                        handle = HookHandle(forward_hook_dict, extra_dict=forward_hook_with_kwargs_dict)
+                        forward_hook_with_kwargs_dict[handle.handle_id] = True
+                    else:
+                        handle = HookHandle(forward_hook_dict)
+                    forward_hook_dict[handle.handle_id] = forward_hook
+                    forward_hook_dict.move_to_end(handle.handle_id, last=False)
+
+                setattr(cell, 'msprobe_forward_hook', True)
+
+            def get_backward_hook(backward_data_hook, full_backward_name):
+                def backward_hook_fn(cell, grad_input, grad_output):
+                    new_output = backward_data_hook(cell, grad_input, grad_output)
+                    self.set_construct_info_in_hook(full_backward_name)
+                    cell.has_pre_hook_called = False
+                    return new_output
+                return backward_hook_fn
+
+            enable_hooked = sum(
+                [isinstance(ele, Tensor) and ele.dtype not in MsConst.NonDifferentiableType for ele in args]
+            )
+            if enable_hooked:
+                backward_hook = OrderedDict()
+                hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
+                backward_hook[full_backward_name] = get_backward_hook(hook_set.backward_hook, full_backward_name)
+                CellProcessor.cell_backward_hook.append(backward_hook)
+                bw_hook = inner.CellBackwardHook(full_backward_name, cell,
+                                                 self.cell_backward_hook[-1])
+                bw_hook.register_backward_hook()
+                CellProcessor.cell_bw_hook_kernels[full_forward_name] = bw_hook
+
+                args = bw_hook(*args)
+
+            return args
+
+        def forward_hook(cell, args, kwargs_or_output, output_or_kwargs=None):
+            index = CellProcessor.cell_count.get(cell_name, 0)
+            full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}'
+            full_backward_name = f'{cell_name}{Const.BACKWARD}{Const.SEP}{index}'
+
+            self.set_construct_info_in_hook(full_forward_name)
+
+            hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
+            hook_result = hook_set.forward_hook(cell, args, kwargs_or_output, output_or_kwargs)
+            if hook_result is not None:
+                outputs = hook_result
             else:
-
+                outputs = output_or_kwargs if has_kwargs_in_forward_hook() else kwargs_or_output
+
+            bw_hook = CellProcessor.cell_bw_hook_kernels.get(full_forward_name)
+            if bw_hook:
+                if not isinstance(outputs, (Tensor, tuple)):
+                    logger.warning("For backward hooks to be called,"
+                                   " cell output should be a Tensor or a tuple of Tensors"
+                                   f" but received {type(outputs)}")
+                if isinstance(outputs, tuple):
+                    new_outputs = bw_hook(*outputs)
+                else:
+                    new_outputs = bw_hook(outputs)
+                if isinstance(outputs, tuple) and len(outputs) == 1:
+                    new_outputs = (new_outputs,)
+                outputs = new_outputs
+
+            def get_backward_pre_hook(full_backward_name, backward_data_hook):
+                def backward_pre_hook_fn(cell, grad_output):
+                    cell.has_pre_hook_called = True
+                    self.set_construct_info_in_pre_hook(full_backward_name)
+                    if backward_data_hook:
+                        backward_data_hook(cell, (), grad_output)
+                        self.set_construct_info_in_hook(full_backward_name)
+                        cell.has_pre_hook_called = False
+                return backward_pre_hook_fn
 
-
-
+            backward_pre_hook = OrderedDict()
+            backward_data_hook = None if bw_hook else hook_set.backward_hook
+            backward_pre_hook[full_backward_name] = get_backward_pre_hook(full_backward_name, backward_data_hook)
+            CellProcessor.cell_backward_pre_hook.append(backward_pre_hook)
+            bw_pre_hook = inner.CellBackwardHook(full_backward_name, cell,
+                                                 self.cell_backward_pre_hook[-1])
+            bw_pre_hook.register_backward_pre_hook()
 
-
+            if isinstance(outputs, tuple):
+                result = bw_pre_hook(*outputs)
+            else:
+                result = bw_pre_hook(outputs)
+            if isinstance(outputs, tuple):
+                if len(outputs) == 1:
+                    result = (result,)
+                if len(result) != len(outputs):
+                    raise TypeError(
+                        f"The backward pre hook return value size is {len(result)} "
+                        f"not equal to output size {len(outputs)}"
+                    )
+            return result
+
+        return forward_pre_hook
 
-    def
-        if
-
+    def set_construct_info_in_pre_hook(self, full_name):
+        if self.cell_stack:
+            CellProcessor.module_node[full_name] = self.cell_stack[-1]
         else:
-
-
-
-
-
+            CellProcessor.module_node[full_name] = None
+        CellProcessor.cell_stack.append(full_name)
+        CellProcessor.api_parent_node = full_name
+        if self.scope:
+            self.scope.begin_module(full_name)
+
+    def set_construct_info_in_hook(self, full_name):
+        if self.cell_stack:
+            CellProcessor.cell_stack.pop()
+            CellProcessor.api_parent_node = CellProcessor.cell_stack[-1] if self.cell_stack else None
+        if self.scope:
+            self.scope.end_module(full_name)
```
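The rewritten `CellProcessor.register_cell_hook` and `build_cell_hook` above are built on MindSpore's Cell hook API (`register_forward_pre_hook` / `register_forward_hook`, effective in PyNative mode). Below is a minimal sketch of that underlying mechanism only; the `ToyNet` cell and hook names are illustrative and not part of msprobe, and older MindSpore versions pass a cell-id string instead of the `Cell` object to the hook.

```python
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)  # Cell hooks only take effect in PyNative mode


class ToyNet(nn.Cell):  # illustrative cell, not part of msprobe
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(4, 2)

    def construct(self, x):
        return self.dense(x)


def forward_pre_hook(cell, inputs):
    # Runs before construct; returning None keeps the inputs unchanged.
    print("pre :", type(cell).__name__, [t.shape for t in inputs])


def forward_hook(cell, inputs, output):
    # Runs after construct; returning None keeps the output unchanged.
    print("post:", type(cell).__name__, output.shape)


net = ToyNet()
net.dense.register_forward_pre_hook(forward_pre_hook)
net.dense.register_forward_hook(forward_hook)
net(Tensor(np.ones((1, 4), np.float32)))
```

Because a hook that returns `None` leaves data untouched, a dump tool can layer observation hooks like these onto every sub-cell without changing the forward computation.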
```diff
--- a/msprobe/mindspore/code_mapping/graph_parser.py
+++ b/msprobe/mindspore/code_mapping/graph_parser.py
@@ -34,19 +34,6 @@ class Parser:
         if isinstance(subgraph_node.attrs, list):
             subgraph_node.attrs.extend(attrs)
 
-    @staticmethod
-    def parse_graph_attributes(text: str, graph_node: GraphNode) -> None:
-        attr_pattern = re.compile(r'# Attrs:\s*(.*)', re.DOTALL)
-        match = attr_pattern.search(text, graph_node.pos)
-        if match:
-            attrs = match.group(1).strip().split('\n')
-            for attr in attrs:
-                if not attr:
-                    break
-                key, value = attr.split(':')
-                if isinstance(graph_node.attrs, dict):
-                    graph_node.attrs[key.strip()] = value.strip()
-
     @staticmethod
     def parse_code_info(text: str, start_pos: int, end_pos: int) -> List[str]:
         code_info = []
@@ -124,8 +111,9 @@
         scope_match = scope_pattern.search(text, end_pos)
         scope = scope_match.group(1) if scope_match else ""
 
-        id_pattern = re.compile(
-
+        id_pattern = re.compile(
+            r'cnode_primal_attrs:'r'\s*\{[\w+]{1, 10000}\b(?:forward_unique_id|unique_id):\s*\"(\d+)\"',
+            re.IGNORECASE)
         unique_id_match = id_pattern.search(text, end_pos, scope_match.start())
         unique_id = unique_id_match.group(1) if unique_id_match else None
 
@@ -186,7 +174,7 @@
             node_info.var_inputs.append(callee_name)
 
     def parse_subgraphs(self, text: str) -> None:
-        subgraph_pattern = re.compile(r'subgraph\s+@(\
+        subgraph_pattern = re.compile(r'/subgraph\s+@([\w+]{1,1000)(\([^\)]{1,100}\))?\s+\S[^\{]\{/+')
         matches = list(subgraph_pattern.finditer(text))
         end_pos = 0
         for match in matches:
@@ -203,11 +191,6 @@
             subgraph_info.end = end_pos
             logging.info('Parsed subgraph: %s', subgraph_name)
 
-    def count_nodes(self) -> Tuple[int, int]:
-        total_nodes = len(self.nodes)
-        total_cnodes = sum(1 for node in self.nodes.values() if node.name.startswith('CNode'))
-        return total_nodes, total_cnodes
-
     def create_backward_map(self):
         for node in self.nodes.values():
             if node.scope and node.scope.startswith("Gradients"):
```
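The regex changes above keep the parser scanning a bounded window of the IR dump via `Pattern.search(text, pos, endpos)` rather than the whole string. A simplified, self-contained stand-in for that lookup (the toy IR line and the reduced pattern are illustrative, not the exact expression used in `graph_parser.py`):

```python
import re

# Toy IR text, only for illustration; real MindSpore IR dumps are far richer.
ir_text = 'cnode_primal_attrs: {forward_unique_id: "42"} scope: Default/net'

# Reduced stand-in for the id pattern: capture the quoted number that follows
# forward_unique_id or unique_id.
id_pattern = re.compile(r'(?:forward_unique_id|unique_id):\s*"(\d+)"')

# Like the parser, restrict the search to the window [pos, endpos) of the text.
match = id_pattern.search(ir_text, 0, len(ir_text))
unique_id = match.group(1) if match else None
print(unique_id)  # -> 42
```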
```diff
--- a/msprobe/mindspore/common/const.py
+++ b/msprobe/mindspore/common/const.py
@@ -15,6 +15,7 @@
 
 import numpy as np
 import mindspore as ms
+from mindspore import dtype as mstype
 
 from msprobe.core.common.const import Const as CoreConst
 
@@ -23,14 +24,20 @@ class Const:
     CELL = "cell"
     API = "api"
     KERNEL = "kernel"
+    CELL_AND_API = 'cell_and_api'
     TOOL_LEVEL_DICT = {
         CoreConst.LEVEL_L0: CELL,
         CoreConst.LEVEL_L1: API,
-        CoreConst.LEVEL_L2: KERNEL
+        CoreConst.LEVEL_L2: KERNEL,
+        CoreConst.LEVEL_MIX: CELL_AND_API
     }
-
+
+    PYNATIVE_MODE = CoreConst.PYNATIVE_MODE
+    GRAPH_MODE = "graph"
     GRAPH_GE_MODE = "graph_ge"
     GRAPH_KBYK_MODE = "graph_kbyk"
+    PYNATIVE_GRAPH_MODE = CoreConst.PYNATIVE_GRAPH_MODE
+
     JIT_LEVEL = "jit_level"
     JIT_LEVEL_O0 = "O0"
     JIT_LEVEL_O1 = "O1"
@@ -61,6 +68,7 @@
     DROPOUT_API_NAME_PREFIX = "dropout"
 
     GRAPH_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.INPUT, CoreConst.OUTPUT]
+    GRAPH_CELL_DUMP_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.FORWARD, CoreConst.BACKWARD]
 
     HOOK_MS_PREFIX_DICT = {
         OPS_DATA_PREFIX: OPS_PREFIX,
@@ -69,6 +77,13 @@
         MINT_NN_FUNC_DATA_PREFIX: MINT_NN_FUNC_PREFIX
     }
 
+    NonDifferentiableType = (
+        mstype.bool_, mstype.int8, mstype.byte, mstype.uint8, mstype.ubyte,
+        mstype.int16, mstype.short, mstype.uint16, mstype.ushort,
+        mstype.int32, mstype.intc, mstype.uint32, mstype.uintc,
+        mstype.int64, mstype.intp, mstype.uint64, mstype.uintp
+    )
+
 
 class MsCompareConst:
     # api_info field
@@ -88,14 +103,11 @@
     MINDTORCH_NPU = "NPU"
     MINDTORCH_DIST = "Distributed"
 
-
-
     MT_VALID_API_TYPES = [
         MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR
     ]
     SUPPORTED_FUSION_LIST = ["flash_attention_score"]
 
-
     TASK_FIELD = "task"
     STATISTICS_TASK = "statistics"
     FRAMEWORK = "framework"
@@ -129,8 +141,6 @@
     EXCEPTION_SKIP = "exception_skip"
 
 
-
-
 class FreeBenchmarkConst:
     ADD_NOISE = "add_noise"
     BIT_NOISE = "bit_noise"
```
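`NonDifferentiableType` is the dtype deny-list that `cell_processor.py` consults to decide whether a cell input can carry gradients and therefore warrants a backward hook. A reduced sketch of that check (the helper name and the shortened dtype tuple are illustrative, not msprobe code):

```python
import numpy as np
from mindspore import Tensor, dtype as mstype

# Illustrative, shortened version of the deny-list: integer and boolean tensors
# are not differentiable, so attaching a backward hook to them is pointless.
NON_DIFFERENTIABLE = (mstype.bool_, mstype.int8, mstype.int32, mstype.int64)


def needs_backward_hook(*args):
    # True if at least one argument is a tensor with a differentiable dtype.
    return any(isinstance(a, Tensor) and a.dtype not in NON_DIFFERENTIABLE for a in args)


print(needs_backward_hook(Tensor(np.ones((2, 2), np.float32))))   # True
print(needs_backward_hook(Tensor(np.ones((2, 2), np.int32)), 3))  # False
```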
```diff
--- a/msprobe/mindspore/common/utils.py
+++ b/msprobe/mindspore/common/utils.py
@@ -13,19 +13,34 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import os
 import random
+import types
 
 import mindspore as ms
-
 from mindspore import ops
+from mindspore.common.jit_config import JitConfig
 from mindspore.mint import nn
 
+from msprobe.core.common.const import Const
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.core.common.exceptions import DistributedNotInitializedError
 from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy
 from msprobe.core.common.log import logger
-from msprobe.core.common.const import Const
 from msprobe.core.common.utils import CompareException, check_seed_all, is_save_variable_valid
+from msprobe.mindspore.common.const import Const as MsConst
+
+try:
+    from mindspore._c_expression import _set_init_iter
+except ImportError:
+    enable_dynamic_kbyk_dump = False
+else:
+    enable_dynamic_kbyk_dump = True
+
+mindtorch_check_result = None
+register_backward_hook_functions = {}
+kwargs_exist_in_forward_hook = None
 
 
 class MsprobeStep(ms.train.Callback):
@@ -33,6 +48,11 @@ class MsprobeStep(ms.train.Callback):
         super(MsprobeStep, self).__init__()
         self.debugger = debugger
 
+    def on_train_begin(self, run_context):
+        self.debugger.start()
+        if enable_dynamic_kbyk_dump:
+            _set_init_iter(0)
+
     def on_train_step_begin(self, run_context):
         self.debugger.start()
 
@@ -82,8 +102,8 @@ def convert_to_int(value):
 
 
 def clean_input_kwargs(cell):
-    if hasattr(cell, '
-        del cell.
+    if hasattr(cell, 'msprobe_input_kwargs'):
+        del cell.msprobe_input_kwargs
 
 
 def list_lowest_level_directories(root_dir):
@@ -152,9 +172,6 @@ def remove_dropout():
     nn.functional.dropout = dropout_ext
 
 
-mindtorch_check_result = None
-
-
 def is_mindtorch():
     global mindtorch_check_result
     if mindtorch_check_result is None:
@@ -169,11 +186,11 @@ def is_mindtorch():
     return mindtorch_check_result
 
 
-register_backward_hook_functions = {}
-
-
 def set_register_backward_hook_functions():
     global register_backward_hook_functions
+    if register_backward_hook_functions:
+        return
+
     if is_mindtorch():
         import torch
         from msprobe.mindspore.mindtorch import (_call_impl,
@@ -192,7 +209,7 @@ def set_register_backward_hook_functions():
 
 def check_save_param(variable, name, save_backward):
     # try catch this api to skip invalid call
-    valid_data_types =
+    valid_data_types = (ms.Tensor, int, float, str)
     if not is_save_variable_valid(variable, valid_data_types):
         valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list)
         logger.warning("PrecisionDebugger.save variable type not valid, "
@@ -209,3 +226,103 @@ def check_save_param(variable, name, save_backward):
                        "should be bool. "
                        "Skip current save process.")
         raise ValueError
+
+
+def is_graph_mode_cell_dump_allowed(config):
+    if config.task not in [Const.TENSOR, Const.STATISTICS] or is_mindtorch() or not hasattr(ops, 'DumpGradient'):
+        return False
+    valid_mix_level = [MsConst.CELL_AND_API, Const.LEVEL_MIX]
+    if config.level in valid_mix_level and config.execution_mode == MsConst.PYNATIVE_MODE:
+        return True
+    return config.level == MsConst.CELL or config.level == Const.LEVEL_L0
+
+
+@recursion_depth_decorator('msprobe.mindspore.common.utils.is_decorated_by_jit')
+def is_decorated_by_jit(func):
+    closure = getattr(func, '__closure__', [])
+    if closure:
+        for obj in closure:
+            if isinstance(obj.cell_contents, JitConfig):
+                return True
+            elif isinstance(obj.cell_contents, types.FunctionType) and hasattr(obj.cell_contents, '__closure__'):
+                if is_decorated_by_jit(obj.cell_contents):
+                    return True
+    return False
+
+
+@recursion_depth_decorator('msprobe.mindspore.common.utils.get_cells_and_names')
+def get_cells_and_names(model, cells_set=None, name_prefix=''):
+    cells_set = cells_set if cells_set else set()
+    if model in cells_set:
+        return
+
+    cells_set.add(model)
+    jit_decorated = is_decorated_by_jit(model.construct)
+    yield name_prefix, model, jit_decorated
+    if jit_decorated:
+        return
+
+    children_cells = getattr(model, '_cells')
+    for name, cell in children_cells.items():
+        if cell:
+            cells_name_prefix = f'{name_prefix}{Const.SEP}{name}' if name_prefix else name
+            jit_decorated = is_decorated_by_jit(model.construct)
+            if jit_decorated:
+                yield cells_name_prefix, cell, jit_decorated
+            else:
+                for ele in get_cells_and_names(cell, cells_set, cells_name_prefix):
+                    yield ele
+
+
+def get_cells_and_names_with_index(models):
+    cells_with_index_in_pynative_mode = {}
+    cells_with_index_in_graph_mode = {}
+
+    def distinguish_cells(cells):
+        cells_in_pynative_mode = []
+        cells_in_graph_mode = []
+        for name, cell, jit_decorated in cells:
+            if jit_decorated:
+                cells_in_graph_mode.append((name, cell))
+            else:
+                cells_in_pynative_mode.append((name, cell))
+        return cells_in_pynative_mode, cells_in_graph_mode
+
+    if is_mindtorch():
+        if isinstance(models, (list, tuple)):
+            for index, model in enumerate(models):
+                cells_with_index_in_pynative_mode[str(index)] = model.named_modules()
+        else:
+            cells_with_index_in_pynative_mode["-1"] = models.named_modules()
+    else:
+        if isinstance(models, (list, tuple)):
+            for index, model in enumerate(models):
+                cells = get_cells_and_names(model)
+                cells_in_pynative_mode, cells_in_graph_mode = distinguish_cells(cells)
+                cells_with_index_in_pynative_mode[str(index)] = cells_in_pynative_mode
+                cells_with_index_in_graph_mode[str(index)] = cells_in_graph_mode
+        else:
+            cells = get_cells_and_names(models)
+            cells_in_pynative_mode, cells_in_graph_mode = distinguish_cells(cells)
+            cells_with_index_in_pynative_mode["-1"] = cells_in_pynative_mode
+            cells_with_index_in_graph_mode["-1"] = cells_in_graph_mode
+
+    return cells_with_index_in_pynative_mode, cells_with_index_in_graph_mode
+
+
+def has_kwargs_in_forward_hook():
+    global kwargs_exist_in_forward_hook
+
+    if kwargs_exist_in_forward_hook is None:
+        if is_mindtorch():
+            kwargs_exist_in_forward_hook = True
+            return kwargs_exist_in_forward_hook
+
+        try:
+            func_params = inspect.signature(nn.Cell.register_forward_hook).parameters
+            kwargs_exist_in_forward_hook = 'with_kwargs' in func_params
+        except Exception:
+            kwargs_exist_in_forward_hook = False
+        return kwargs_exist_in_forward_hook
+
+    return kwargs_exist_in_forward_hook
```
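`has_kwargs_in_forward_hook` above feature-detects the `with_kwargs` parameter by inspecting the hook-registration signature once and caching the result. The same stdlib idiom in isolation (the example callables below are hypothetical and not msprobe or MindSpore APIs):

```python
import inspect


def supports_kwarg(func, kwarg_name):
    """Return True if `func` accepts a keyword argument named `kwarg_name`.

    Same idiom as has_kwargs_in_forward_hook: resolve the signature once and
    fall back to False when it cannot be inspected (e.g. some builtins).
    """
    try:
        params = inspect.signature(func).parameters
    except (TypeError, ValueError):
        return False
    return kwarg_name in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )


# Hypothetical example callables, not from msprobe:
def register_forward_hook(hook, with_kwargs=False): ...
def old_register_forward_hook(hook): ...


print(supports_kwarg(register_forward_hook, "with_kwargs"))      # True
print(supports_kwarg(old_register_forward_hook, "with_kwargs"))  # False
```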