mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
- mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +16 -21
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +185 -11
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +4 -0
- msprobe/core/common/utils.py +42 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +381 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +199 -69
- msprobe/core/data_dump/data_collector.py +100 -25
- msprobe/core/data_dump/data_processor/base.py +130 -28
- msprobe/core/data_dump/data_processor/factory.py +8 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
- msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
- msprobe/core/data_dump/json_writer.py +54 -8
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +121 -17
- msprobe/docs/02.config_introduction.md +18 -16
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +107 -58
- msprobe/docs/06.data_dump_MindSpore.md +95 -34
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +310 -220
- msprobe/docs/21.visualization_PyTorch.md +125 -35
- msprobe/docs/22.visualization_MindSpore.md +149 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +525 -0
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +11 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +68 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +28 -2
- msprobe/mindspore/debugger/precision_debugger.py +100 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +13 -3
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +347 -107
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -7
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +116 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +42 -17
- msprobe/pytorch/debugger/precision_debugger.py +56 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +98 -28
- msprobe/pytorch/monitor/csv2tb.py +164 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +543 -318
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +76 -56
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +84 -48
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +264 -115
- msprobe/visualization/builder/graph_builder.py +93 -10
- msprobe/visualization/builder/msprobe_adapter.py +30 -6
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +15 -19
- msprobe/visualization/graph/distributed_analyzer.py +395 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +100 -27
- msprobe/visualization/utils.py +24 -31
- mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
- msprobe/pytorch/functional/module_dump.py +0 -84
- msprobe/pytorch/module_processer.py +0 -150
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -14,15 +14,20 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
-
from collections import defaultdict
|
|
17
|
+
from collections import defaultdict, namedtuple
|
|
18
18
|
|
|
19
19
|
import mindspore as ms
|
|
20
20
|
from mindspore._c_expression import MSContext
|
|
21
21
|
|
|
22
|
-
from msprobe.core.common.const import Const, MsgConst
|
|
22
|
+
from msprobe.core.common.const import Const, FileCheckConst, MsgConst
|
|
23
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
24
|
+
from msprobe.core.common.file_utils import FileChecker
|
|
25
|
+
from msprobe.core.common.utils import get_real_step_or_rank
|
|
23
26
|
from msprobe.mindspore.cell_processor import CellProcessor
|
|
24
27
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
28
|
+
from msprobe.mindspore.common.utils import set_register_backward_hook_functions, check_save_param
|
|
25
29
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
30
|
+
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
26
31
|
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
27
32
|
from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor
|
|
28
33
|
from msprobe.mindspore.ms_config import parse_json_config
|
|
@@ -30,12 +35,21 @@ from msprobe.mindspore.runtime import Runtime
|
|
|
30
35
|
from msprobe.mindspore.service import Service
|
|
31
36
|
from msprobe.mindspore.task_handler_factory import TaskHandlerFactory
|
|
32
37
|
|
|
38
|
+
try:
|
|
39
|
+
from msprobe.lib import _msprobe_c
|
|
40
|
+
except ImportError:
|
|
41
|
+
_msprobe_c = None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task", "dump_path", "level"])
|
|
45
|
+
|
|
33
46
|
|
|
34
47
|
class PrecisionDebugger:
|
|
35
48
|
_instance = None
|
|
36
49
|
task_not_need_service = [Const.GRAD_PROBE]
|
|
37
50
|
|
|
38
|
-
def __new__(cls, config_path=None,
|
|
51
|
+
def __new__(cls, config_path=None, task=None, dump_path=None,
|
|
52
|
+
level=None, step=None, opt=None):
|
|
39
53
|
if not cls._instance:
|
|
40
54
|
cls._instance = super().__new__(cls)
|
|
41
55
|
cls._instance.initialized = False
|
|
@@ -44,22 +58,66 @@ class PrecisionDebugger:
|
|
|
44
58
|
cls.first_start = False
|
|
45
59
|
return cls._instance
|
|
46
60
|
|
|
47
|
-
def __init__(self, config_path=None
|
|
61
|
+
def __init__(self, config_path=None, task=None, dump_path=None,
|
|
62
|
+
level=None, step=None):
|
|
48
63
|
if self.initialized:
|
|
49
64
|
return
|
|
50
65
|
self.initialized = True
|
|
66
|
+
|
|
67
|
+
set_register_backward_hook_functions()
|
|
68
|
+
|
|
51
69
|
if not config_path:
|
|
52
70
|
config_path = os.path.join(os.path.dirname(__file__), "../../config.json")
|
|
71
|
+
|
|
72
|
+
config_params = ConfigParameters(config_path, task, dump_path, level)
|
|
73
|
+
self.check_input_params(config_params)
|
|
74
|
+
|
|
53
75
|
common_config, task_config = parse_json_config(config_path)
|
|
76
|
+
common_config.task = task if task else common_config.task
|
|
54
77
|
self.task = common_config.task
|
|
55
78
|
if self.task == Const.GRAD_PROBE:
|
|
56
79
|
self.gm = GradientMonitor(common_config, task_config)
|
|
57
80
|
return
|
|
81
|
+
common_config.step = get_real_step_or_rank(
|
|
82
|
+
step, Const.STEP) if step is not None else common_config.step
|
|
83
|
+
common_config.level = level if level else common_config.level
|
|
84
|
+
common_config.dump_path = dump_path if dump_path else common_config.dump_path
|
|
58
85
|
self.config = DebuggerConfig(common_config, task_config)
|
|
59
86
|
|
|
87
|
+
if _msprobe_c:
|
|
88
|
+
_msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path)
|
|
89
|
+
|
|
90
|
+
self.config.execution_mode = self._get_execution_mode()
|
|
91
|
+
if self._need_service():
|
|
92
|
+
self.config.check_config_with_l2()
|
|
93
|
+
self.service = Service(self.config)
|
|
94
|
+
|
|
60
95
|
Runtime.step_count = 0
|
|
61
96
|
Runtime.is_running = False
|
|
62
97
|
|
|
98
|
+
@staticmethod
|
|
99
|
+
def check_input_params(args):
|
|
100
|
+
if args.config_path is not None:
|
|
101
|
+
if not isinstance(args.config_path, str):
|
|
102
|
+
raise MsprobeException(
|
|
103
|
+
MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string")
|
|
104
|
+
file_checker = FileChecker(
|
|
105
|
+
file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
|
|
106
|
+
file_checker.common_check()
|
|
107
|
+
|
|
108
|
+
if args.task is not None and args.task not in Const.TASK_LIST:
|
|
109
|
+
raise MsprobeException(
|
|
110
|
+
MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}")
|
|
111
|
+
|
|
112
|
+
if args.dump_path is not None:
|
|
113
|
+
if not isinstance(args.dump_path, str):
|
|
114
|
+
raise MsprobeException(
|
|
115
|
+
MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string")
|
|
116
|
+
|
|
117
|
+
if args.level is not None and args.level not in Const.LEVEL_LIST:
|
|
118
|
+
raise MsprobeException(
|
|
119
|
+
MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
|
|
120
|
+
|
|
63
121
|
@staticmethod
|
|
64
122
|
def _get_execution_mode():
|
|
65
123
|
jit_level = ms.context.get_jit_config().get(MsConst.JIT_LEVEL)
|
|
@@ -78,11 +136,23 @@ class PrecisionDebugger:
|
|
|
78
136
|
else:
|
|
79
137
|
return MsConst.PYNATIVE_MODE
|
|
80
138
|
|
|
139
|
+
@staticmethod
|
|
140
|
+
def _is_graph_dump(config):
|
|
141
|
+
if config.level != MsConst.KERNEL:
|
|
142
|
+
return False
|
|
143
|
+
if not config.list:
|
|
144
|
+
return True
|
|
145
|
+
is_graph = any(item.startswith("name-regex") for item in config.list)
|
|
146
|
+
is_graph |= all("." not in item for item in config.list)
|
|
147
|
+
return is_graph
|
|
148
|
+
|
|
81
149
|
@classmethod
|
|
82
150
|
def start(cls, model=None):
|
|
83
151
|
instance = cls._instance
|
|
84
152
|
if not instance:
|
|
85
153
|
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
154
|
+
if _msprobe_c:
|
|
155
|
+
_msprobe_c._PrecisionDebugger().start()
|
|
86
156
|
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
87
157
|
return
|
|
88
158
|
|
|
@@ -93,6 +163,7 @@ class PrecisionDebugger:
|
|
|
93
163
|
instance.service.start(model)
|
|
94
164
|
else:
|
|
95
165
|
if not instance.first_start:
|
|
166
|
+
api_register.api_set_ori_func()
|
|
96
167
|
handler = TaskHandlerFactory.create(instance.config)
|
|
97
168
|
handler.handle()
|
|
98
169
|
|
|
@@ -102,18 +173,15 @@ class PrecisionDebugger:
|
|
|
102
173
|
@classmethod
|
|
103
174
|
def forward_backward_dump_end(cls):
|
|
104
175
|
instance = cls._instance
|
|
105
|
-
|
|
106
|
-
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
107
|
-
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
108
|
-
return
|
|
109
|
-
if instance.service:
|
|
110
|
-
instance.service.forward_backward_dump_end()
|
|
176
|
+
instance.stop()
|
|
111
177
|
|
|
112
178
|
@classmethod
|
|
113
179
|
def stop(cls):
|
|
114
180
|
instance = cls._instance
|
|
115
181
|
if not instance:
|
|
116
182
|
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
183
|
+
if _msprobe_c:
|
|
184
|
+
_msprobe_c._PrecisionDebugger().stop()
|
|
117
185
|
if instance.task == Const.GRAD_PROBE:
|
|
118
186
|
instance.gm.stop()
|
|
119
187
|
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
@@ -127,6 +195,8 @@ class PrecisionDebugger:
|
|
|
127
195
|
instance = cls._instance
|
|
128
196
|
if not instance:
|
|
129
197
|
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
198
|
+
if _msprobe_c:
|
|
199
|
+
_msprobe_c._PrecisionDebugger().step()
|
|
130
200
|
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
131
201
|
return
|
|
132
202
|
if instance.service:
|
|
@@ -145,6 +215,24 @@ class PrecisionDebugger:
|
|
|
145
215
|
return
|
|
146
216
|
instance.gm.monitor(opt)
|
|
147
217
|
|
|
218
|
+
@classmethod
|
|
219
|
+
def save(cls, variable, name, save_backward=True):
|
|
220
|
+
instance = cls._instance
|
|
221
|
+
if not instance:
|
|
222
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
223
|
+
if instance.task not in [Const.TENSOR, Const.STATISTICS] or instance.config.level_ori != Const.LEVEL_DEBUG:
|
|
224
|
+
return
|
|
225
|
+
try:
|
|
226
|
+
check_save_param(variable, name, save_backward)
|
|
227
|
+
except ValueError:
|
|
228
|
+
return
|
|
229
|
+
|
|
230
|
+
instance.config.execution_mode = cls._get_execution_mode()
|
|
231
|
+
if cls._need_service():
|
|
232
|
+
if not instance.service:
|
|
233
|
+
instance.service = Service(instance.config)
|
|
234
|
+
instance.service.save(variable, name, save_backward)
|
|
235
|
+
|
|
148
236
|
@classmethod
|
|
149
237
|
def _need_service(cls):
|
|
150
238
|
instance = cls._instance
|
|
@@ -153,4 +241,4 @@ class PrecisionDebugger:
|
|
|
153
241
|
if instance.config.execution_mode != MsConst.PYNATIVE_MODE:
|
|
154
242
|
return False
|
|
155
243
|
else:
|
|
156
|
-
return instance.config.task != Const.FREE_BENCHMARK and instance.config
|
|
244
|
+
return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config)
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
2
3
|
#
|
|
3
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
5
|
# you may not use this file except in compliance with the License.
|
|
@@ -11,7 +12,6 @@
|
|
|
11
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
13
|
# See the License for the specific language governing permissions and
|
|
13
14
|
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
15
|
|
|
16
16
|
from mindspore import Tensor, ops, mint
|
|
17
17
|
from mindspore.mint.nn import functional
|
|
@@ -20,8 +20,15 @@ from mindspore.communication import comm_func
|
|
|
20
20
|
|
|
21
21
|
from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTensor, HOOKFunctionalOP,
|
|
22
22
|
HOOKMintOP, HOOKMintNNFunctionalOP, HOOKDistributedOP,
|
|
23
|
-
|
|
23
|
+
HOOKTorchOP, HOOKTorchTensor, HOOKTorchFunctionalOP,
|
|
24
|
+
HOOKTorchDistributedOP, HOOKTorchNpuOP,
|
|
25
|
+
get_wrap_api_list, get_wrap_torch_api_list, setup_hooks)
|
|
24
26
|
from msprobe.core.common.utils import Const
|
|
27
|
+
from msprobe.mindspore.common.utils import is_mindtorch
|
|
28
|
+
|
|
29
|
+
if is_mindtorch():
|
|
30
|
+
import torch
|
|
31
|
+
import torch_npu
|
|
25
32
|
|
|
26
33
|
|
|
27
34
|
def stub_method(method):
|
|
@@ -40,6 +47,12 @@ class ApiRegistry:
|
|
|
40
47
|
self.distributed_ori_attr = {}
|
|
41
48
|
self.norm_inner_ops_ori_attr = {}
|
|
42
49
|
|
|
50
|
+
self.torch_ori_attr = {}
|
|
51
|
+
self.torch_tensor_ori_attr = {}
|
|
52
|
+
self.torch_functional_ori_attr = {}
|
|
53
|
+
self.torch_distributed_ori_attr = {}
|
|
54
|
+
self.torch_npu_ori_attr = {}
|
|
55
|
+
|
|
43
56
|
self.tensor_hook_attr = {}
|
|
44
57
|
self.stub_tensor_hook_attr = {}
|
|
45
58
|
self.functional_hook_attr = {}
|
|
@@ -48,6 +61,12 @@ class ApiRegistry:
|
|
|
48
61
|
self.distibuted_hook_attr = {}
|
|
49
62
|
self.norm_inner_ops_hook_attr = {}
|
|
50
63
|
|
|
64
|
+
self.torch_hook_attr = {}
|
|
65
|
+
self.torch_tensor_hook_attr = {}
|
|
66
|
+
self.torch_functional_hook_attr = {}
|
|
67
|
+
self.torch_distributed_hook_attr = {}
|
|
68
|
+
self.torch_npu_hook_attr = {}
|
|
69
|
+
|
|
51
70
|
self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
|
|
52
71
|
|
|
53
72
|
@staticmethod
|
|
@@ -82,22 +101,73 @@ class ApiRegistry:
|
|
|
82
101
|
self.set_api_attr(ops, self.norm_inner_ops_ori_attr)
|
|
83
102
|
|
|
84
103
|
def api_set_hook_func(self):
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
104
|
+
if is_mindtorch():
|
|
105
|
+
self.set_api_attr(torch, self.torch_hook_attr)
|
|
106
|
+
self.set_api_attr(torch.Tensor, self.torch_tensor_hook_attr)
|
|
107
|
+
self.set_api_attr(torch.nn.functional, self.torch_functional_hook_attr)
|
|
108
|
+
self.set_api_attr(torch.distributed, self.torch_distributed_hook_attr)
|
|
109
|
+
self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_hook_attr)
|
|
110
|
+
self.set_api_attr(torch_npu, self.torch_npu_hook_attr)
|
|
111
|
+
else:
|
|
112
|
+
self.set_api_attr(Tensor, self.tensor_hook_attr)
|
|
113
|
+
self.set_api_attr(StubTensor, self.stub_tensor_hook_attr)
|
|
114
|
+
self.set_api_attr(ops, self.functional_hook_attr)
|
|
115
|
+
self.set_api_attr(mint, self.mint_ops_hook_attr)
|
|
116
|
+
self.set_api_attr(functional, self.mint_func_ops_hook_attr)
|
|
117
|
+
self.set_api_attr(comm_func, self.distibuted_hook_attr)
|
|
91
118
|
|
|
92
119
|
def api_set_ori_func(self):
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
120
|
+
if is_mindtorch():
|
|
121
|
+
self.set_api_attr(torch, self.torch_ori_attr)
|
|
122
|
+
self.set_api_attr(torch.Tensor, self.torch_tensor_ori_attr)
|
|
123
|
+
self.set_api_attr(torch.nn.functional, self.torch_functional_ori_attr)
|
|
124
|
+
self.set_api_attr(torch.distributed, self.torch_distributed_ori_attr)
|
|
125
|
+
self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_ori_attr)
|
|
126
|
+
self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
|
|
127
|
+
else:
|
|
128
|
+
self.set_api_attr(Tensor, self.tensor_ori_attr)
|
|
129
|
+
self.set_api_attr(StubTensor, self.stub_tensor_ori_attr)
|
|
130
|
+
self.set_api_attr(ops, self.functional_ori_attr)
|
|
131
|
+
self.set_api_attr(mint, self.mint_ops_ori_attr)
|
|
132
|
+
self.set_api_attr(functional, self.mint_func_ops_ori_attr)
|
|
133
|
+
self.set_api_attr(comm_func, self.distributed_ori_attr)
|
|
99
134
|
|
|
100
135
|
def initialize_hook(self, hook):
|
|
136
|
+
setup_hooks(hook)
|
|
137
|
+
if is_mindtorch():
|
|
138
|
+
wrap_torch_api_name = get_wrap_torch_api_list()
|
|
139
|
+
self.store_ori_attr(torch,
|
|
140
|
+
wrap_torch_api_name.torch_api_names, self.torch_ori_attr)
|
|
141
|
+
self.store_ori_attr(torch.Tensor,
|
|
142
|
+
wrap_torch_api_name.tensor_api_names, self.torch_tensor_ori_attr)
|
|
143
|
+
self.store_ori_attr(torch.nn.functional,
|
|
144
|
+
wrap_torch_api_name.functional_api_names, self.torch_functional_ori_attr)
|
|
145
|
+
self.store_ori_attr(torch.distributed,
|
|
146
|
+
wrap_torch_api_name.distributed_api_names, self.torch_distributed_ori_attr)
|
|
147
|
+
self.store_ori_attr(torch_npu,
|
|
148
|
+
wrap_torch_api_name.npu_api_names, self.torch_npu_ori_attr)
|
|
149
|
+
for attr_name in dir(HOOKTorchOP):
|
|
150
|
+
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
|
|
151
|
+
api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
|
|
152
|
+
self.torch_hook_attr[api_name] = getattr(HOOKTorchOP, attr_name)
|
|
153
|
+
for attr_name in dir(HOOKTorchTensor):
|
|
154
|
+
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
|
|
155
|
+
api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
|
|
156
|
+
self.torch_tensor_hook_attr[api_name] = getattr(HOOKTorchTensor, attr_name)
|
|
157
|
+
for attr_name in dir(HOOKTorchFunctionalOP):
|
|
158
|
+
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
|
|
159
|
+
api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
|
|
160
|
+
self.torch_functional_hook_attr[api_name] = getattr(HOOKTorchFunctionalOP, attr_name)
|
|
161
|
+
for attr_name in dir(HOOKTorchDistributedOP):
|
|
162
|
+
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
|
|
163
|
+
api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
|
|
164
|
+
self.torch_distributed_hook_attr[api_name] = getattr(HOOKTorchDistributedOP, attr_name)
|
|
165
|
+
for attr_name in dir(HOOKTorchNpuOP):
|
|
166
|
+
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
|
|
167
|
+
api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
|
|
168
|
+
self.torch_npu_hook_attr[api_name] = getattr(HOOKTorchNpuOP, attr_name)
|
|
169
|
+
return
|
|
170
|
+
|
|
101
171
|
wrap_api_name = get_wrap_api_list()
|
|
102
172
|
self.store_ori_attr(Tensor, wrap_api_name.tensor_api_names, self.tensor_ori_attr)
|
|
103
173
|
self.store_ori_attr(StubTensor, wrap_api_name.stub_tensor_api_names, self.stub_tensor_ori_attr)
|
|
@@ -106,7 +176,6 @@ class ApiRegistry:
|
|
|
106
176
|
self.store_ori_attr(functional, wrap_api_name.mint_nn_func_api_names, self.mint_func_ops_ori_attr)
|
|
107
177
|
self.store_ori_attr(comm_func, wrap_api_name.distributed_api_names, self.distributed_ori_attr)
|
|
108
178
|
self.store_ori_attr(ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
|
|
109
|
-
setup_hooks(hook)
|
|
110
179
|
for attr_name in dir(HOOKTensor):
|
|
111
180
|
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
|
|
112
181
|
api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
2
3
|
#
|
|
3
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
5
|
# you may not use this file except in compliance with the License.
|
|
@@ -11,45 +12,66 @@
|
|
|
11
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
13
|
# See the License for the specific language governing permissions and
|
|
13
14
|
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
15
|
|
|
16
16
|
from collections import defaultdict
|
|
17
17
|
|
|
18
18
|
from mindspore import nn
|
|
19
19
|
|
|
20
|
-
from msprobe.
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
cell_count
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
20
|
+
from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def add_cell_count(name):
|
|
24
|
+
HOOKCell.cell_count[name] += 1
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_cell_count(name):
|
|
28
|
+
return HOOKCell.cell_count[name]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def __init__(self, build_hook) -> None:
|
|
32
|
+
super(HOOKCell, self).__init__()
|
|
33
|
+
self.changed_status = False
|
|
34
|
+
self.input_kwargs = {}
|
|
35
|
+
self.prefix = ""
|
|
36
|
+
if not HOOKCell.g_stop_hook:
|
|
37
|
+
HOOKCell.g_stop_hook = True
|
|
38
|
+
self.changed_status = True
|
|
39
|
+
if hasattr(self, "prefix_api_name"):
|
|
40
|
+
self.prefix = self.prefix_api_name
|
|
41
|
+
|
|
42
|
+
self.forward_data_collected = False
|
|
43
|
+
forward_pre_hook, forward_hook, backward_hook, backward_pre_hook = build_hook(self.prefix)
|
|
44
|
+
self.register_forward_pre_hook(forward_pre_hook)
|
|
45
|
+
self.register_forward_hook(forward_hook)
|
|
46
|
+
register_backward_hook_functions["full"](self, backward_hook)
|
|
47
|
+
register_backward_hook_functions["pre"](self, backward_pre_hook)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# 重载call,加全局标志。
|
|
51
|
+
def __call__(self, *args, **kwargs):
|
|
52
|
+
try:
|
|
53
|
+
self.input_kwargs = kwargs
|
|
54
|
+
out = super(HOOKCell, self).__call__(*args, **kwargs)
|
|
55
|
+
except Exception as e:
|
|
56
|
+
raise e
|
|
57
|
+
finally:
|
|
58
|
+
if self.changed_status:
|
|
59
|
+
self.changed_status = False
|
|
60
|
+
HOOKCell.g_stop_hook = False
|
|
61
|
+
return out
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
hook_cell_dict = {
|
|
65
|
+
"cell_count": defaultdict(int),
|
|
66
|
+
"g_stop_hook": False,
|
|
67
|
+
"add_cell_count": staticmethod(add_cell_count),
|
|
68
|
+
"get_cell_count": staticmethod(get_cell_count),
|
|
69
|
+
"__init__": __init__,
|
|
70
|
+
"__call__": __call__
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
if is_mindtorch():
|
|
74
|
+
import torch
|
|
75
|
+
HOOKCell = type("HOOKCell", (torch.nn.Module,), hook_cell_dict)
|
|
76
|
+
else:
|
|
77
|
+
HOOKCell = type("HOOKCell", (nn.Cell,), hook_cell_dict)
|
|
@@ -135,6 +135,34 @@ class PrimitiveHookService:
|
|
|
135
135
|
return tuple(hooked_outputs)
|
|
136
136
|
return out
|
|
137
137
|
|
|
138
|
+
def pre_forward_hook(primitive_name, primitive_instance, args, kwargs):
|
|
139
|
+
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
|
|
140
|
+
try:
|
|
141
|
+
self.service_instance.data_collector.forward_input_data_collect(
|
|
142
|
+
primitive_name,
|
|
143
|
+
primitive_instance,
|
|
144
|
+
os.getpid(),
|
|
145
|
+
module_input_output
|
|
146
|
+
)
|
|
147
|
+
except Exception as exception:
|
|
148
|
+
logger.error(f"This is a primitive op dump error during forward input data collection: {exception}, "
|
|
149
|
+
f"primitive_name: {primitive_name}")
|
|
150
|
+
raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
|
|
151
|
+
|
|
152
|
+
def post_forward_hook(primitive_name, primitive_instance, args, kwargs, output):
|
|
153
|
+
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
|
|
154
|
+
try:
|
|
155
|
+
self.service_instance.data_collector.forward_output_data_collect(
|
|
156
|
+
primitive_name,
|
|
157
|
+
primitive_instance,
|
|
158
|
+
os.getpid(),
|
|
159
|
+
module_input_output
|
|
160
|
+
)
|
|
161
|
+
except Exception as exception:
|
|
162
|
+
logger.error(f"This is a primitive op dump error during forward output data collection: {exception}, "
|
|
163
|
+
f"primitive_name: {primitive_name}")
|
|
164
|
+
raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
|
|
165
|
+
|
|
138
166
|
def wrapped_primitive_call(instance_self, *args, **kwargs):
|
|
139
167
|
"""
|
|
140
168
|
包装后的 primitive 调用函数,添加输入和输出的 hook。
|
|
@@ -163,27 +191,17 @@ class PrimitiveHookService:
|
|
|
163
191
|
f"primitive_name: {primitive_name}")
|
|
164
192
|
raise DumpException(DumpException.INPUT_HOOK_ERROR) from exception
|
|
165
193
|
|
|
194
|
+
forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}"
|
|
195
|
+
self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
|
|
196
|
+
|
|
197
|
+
pre_forward_hook(forward_primitive_name, instance_self, hooked_inputs, kwargs)
|
|
166
198
|
try:
|
|
167
199
|
out = origin_func(*hooked_inputs, **kwargs)
|
|
168
200
|
except Exception as exception:
|
|
169
201
|
logger.error(f"This is a primitive op dump error during function call: {exception}, "
|
|
170
202
|
f"primitive_name: {primitive_name}")
|
|
171
203
|
raise DumpException(DumpException.FUNCTION_CALL_ERROR) from exception
|
|
172
|
-
|
|
173
|
-
forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}"
|
|
174
|
-
self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
|
|
175
|
-
if self.service_instance.data_collector:
|
|
176
|
-
module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
|
|
177
|
-
try:
|
|
178
|
-
self.service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
|
|
179
|
-
os.getpid(), module_input_output)
|
|
180
|
-
except Exception as exception:
|
|
181
|
-
logger.error(f"This is a primitive op dump error during forward data collection: {exception}, "
|
|
182
|
-
f"primitive_name: {primitive_name}")
|
|
183
|
-
raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
|
|
184
|
-
|
|
185
|
-
if self.service_instance.data_collector.if_return_forward_new_output():
|
|
186
|
-
out = self.service_instance.data_collector.get_forward_new_output()
|
|
204
|
+
post_forward_hook(forward_primitive_name, instance_self, hooked_inputs, kwargs, out)
|
|
187
205
|
|
|
188
206
|
try:
|
|
189
207
|
out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
# List of ops that register hooks
|
|
17
17
|
|
|
18
|
-
|
|
18
|
+
|
|
19
19
|
ops:
|
|
20
20
|
- adaptive_avg_pool1d
|
|
21
21
|
- adaptive_avg_pool2d
|
|
@@ -85,6 +85,7 @@ ops:
|
|
|
85
85
|
- relu6
|
|
86
86
|
- celu
|
|
87
87
|
- rrelu
|
|
88
|
+
- rms_norm
|
|
88
89
|
- selu
|
|
89
90
|
- sigmoid
|
|
90
91
|
- silu
|
|
@@ -553,6 +554,7 @@ tensor:
|
|
|
553
554
|
- acos
|
|
554
555
|
- acosh
|
|
555
556
|
- add
|
|
557
|
+
- add_
|
|
556
558
|
- addbmm
|
|
557
559
|
- addcdiv
|
|
558
560
|
- addcmul
|
|
@@ -607,6 +609,7 @@ tensor:
|
|
|
607
609
|
- diff
|
|
608
610
|
- digamma
|
|
609
611
|
- div
|
|
612
|
+
- div_
|
|
610
613
|
- divide
|
|
611
614
|
- equal
|
|
612
615
|
- erf
|
|
@@ -739,6 +742,8 @@ tensor:
|
|
|
739
742
|
- square
|
|
740
743
|
- squeeze
|
|
741
744
|
- std
|
|
745
|
+
- sub
|
|
746
|
+
- sub_
|
|
742
747
|
- subtract
|
|
743
748
|
- subtract
|
|
744
749
|
- svd
|
|
@@ -983,6 +988,7 @@ mint.nn.functional:
|
|
|
983
988
|
- one_hot_ext
|
|
984
989
|
- pad
|
|
985
990
|
- relu
|
|
991
|
+
- relu_
|
|
986
992
|
- sigmoid
|
|
987
993
|
- silu
|
|
988
994
|
- softmax
|
|
@@ -1017,3 +1023,7 @@ communication.comm_func:
|
|
|
1017
1023
|
- broadcast
|
|
1018
1024
|
- gather_into_tensor
|
|
1019
1025
|
- scatter_tensor
|
|
1026
|
+
- send
|
|
1027
|
+
- recv
|
|
1028
|
+
- isend
|
|
1029
|
+
- irecv
|