mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
- mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +16 -21
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +185 -11
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +4 -0
- msprobe/core/common/utils.py +42 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +381 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +199 -69
- msprobe/core/data_dump/data_collector.py +100 -25
- msprobe/core/data_dump/data_processor/base.py +130 -28
- msprobe/core/data_dump/data_processor/factory.py +8 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
- msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
- msprobe/core/data_dump/json_writer.py +54 -8
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +121 -17
- msprobe/docs/02.config_introduction.md +18 -16
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +107 -58
- msprobe/docs/06.data_dump_MindSpore.md +95 -34
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +310 -220
- msprobe/docs/21.visualization_PyTorch.md +125 -35
- msprobe/docs/22.visualization_MindSpore.md +149 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +525 -0
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +11 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +68 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +28 -2
- msprobe/mindspore/debugger/precision_debugger.py +100 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +13 -3
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +347 -107
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -7
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +116 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +42 -17
- msprobe/pytorch/debugger/precision_debugger.py +56 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +98 -28
- msprobe/pytorch/monitor/csv2tb.py +164 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +543 -318
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +76 -56
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +84 -48
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +264 -115
- msprobe/visualization/builder/graph_builder.py +93 -10
- msprobe/visualization/builder/msprobe_adapter.py +30 -6
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +15 -19
- msprobe/visualization/graph/distributed_analyzer.py +395 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +100 -27
- msprobe/visualization/utils.py +24 -31
- mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
- msprobe/pytorch/functional/module_dump.py +0 -84
- msprobe/pytorch/module_processer.py +0 -150
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/pytorch/service.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -15,24 +15,24 @@
|
|
|
15
15
|
|
|
16
16
|
import functools
|
|
17
17
|
import os
|
|
18
|
-
from collections import namedtuple
|
|
18
|
+
from collections import namedtuple, defaultdict
|
|
19
19
|
|
|
20
20
|
import torch
|
|
21
21
|
from msprobe.core.common.const import Const
|
|
22
|
-
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
22
|
+
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
23
23
|
from msprobe.core.common.file_utils import create_directory
|
|
24
|
-
from msprobe.core.common.utils import print_tools_ends_info
|
|
24
|
+
from msprobe.core.common.utils import print_tools_ends_info, DumpPathAggregation
|
|
25
25
|
from msprobe.core.data_dump.data_collector import build_data_collector
|
|
26
26
|
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
|
|
27
27
|
from msprobe.core.data_dump.scope import BaseScope
|
|
28
28
|
from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
|
|
29
29
|
from msprobe.pytorch.common.log import logger
|
|
30
|
-
from msprobe.pytorch.common.utils import get_rank_if_initialized
|
|
30
|
+
from msprobe.pytorch.common.utils import get_rank_if_initialized, is_recomputation
|
|
31
31
|
from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
|
|
32
|
-
from msprobe.pytorch.
|
|
32
|
+
from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
|
|
33
33
|
from msprobe.pytorch.hook_module.api_registry import api_register
|
|
34
34
|
from msprobe.pytorch.hook_module.hook_module import HOOKModule
|
|
35
|
-
from msprobe.pytorch.
|
|
35
|
+
from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
|
|
36
36
|
|
|
37
37
|
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
38
38
|
if torch_version_above_or_equal_2:
|
|
@@ -48,100 +48,206 @@ class Service:
|
|
|
48
48
|
self.data_collector = build_data_collector(config)
|
|
49
49
|
self.module_processor = ModuleProcesser(self.data_collector.scope)
|
|
50
50
|
self.switch = False
|
|
51
|
+
self.inner_switch = False
|
|
51
52
|
self.current_iter = 0
|
|
52
53
|
self.first_start = True
|
|
53
54
|
self.current_rank = None
|
|
54
55
|
self.dump_iter_dir = None
|
|
55
56
|
self.should_stop_service = False
|
|
56
57
|
self.attl = None
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
@staticmethod
|
|
64
|
-
def is_registered_backward_hook(module):
|
|
65
|
-
if hasattr(module, '_backward_hooks') and \
|
|
66
|
-
len(module._backward_hooks) > 0 and \
|
|
67
|
-
module._is_full_backward_hook is False:
|
|
68
|
-
return True
|
|
69
|
-
return False
|
|
70
|
-
|
|
71
|
-
def check_register_full_backward_hook(self, module):
|
|
72
|
-
if self.is_registered_backward_hook(module):
|
|
73
|
-
module._backward_hooks.clear()
|
|
74
|
-
module._is_full_backward_hook = None
|
|
75
|
-
logger.warning("Found deprecated backward hooks. Removing them and switching to full backward hooks.")
|
|
58
|
+
self.params_grad_info = {}
|
|
59
|
+
self.hook_handle_dict = {}
|
|
60
|
+
# 提前注册,确保注册尽可能多的API hook
|
|
61
|
+
self.register_api_hook()
|
|
62
|
+
self.init_for_debug_level()
|
|
76
63
|
|
|
77
64
|
def build_hook(self, module_type, name):
|
|
78
65
|
def pre_hook(api_or_module_name, module, args, kwargs):
|
|
79
|
-
if not self.should_execute_hook():
|
|
66
|
+
if not self.should_execute_hook(module_type, module, True):
|
|
80
67
|
return args, kwargs
|
|
68
|
+
is_recompute = is_recomputation()
|
|
81
69
|
|
|
70
|
+
self.inner_switch = True
|
|
82
71
|
if module_type == BaseScope.Module_Type_Module:
|
|
83
|
-
api_or_module_name = module.mindstudio_reserved_name
|
|
72
|
+
api_or_module_name = module.mindstudio_reserved_name[-1]
|
|
73
|
+
else:
|
|
74
|
+
module.forward_data_collected = True
|
|
75
|
+
HOOKModule.add_module_count(name)
|
|
84
76
|
self.data_collector.update_api_or_module_name(api_or_module_name)
|
|
85
77
|
|
|
86
78
|
if self.config.online_run_ut:
|
|
79
|
+
self.inner_switch = False
|
|
87
80
|
return None, None
|
|
88
81
|
if self.data_collector:
|
|
89
82
|
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
|
|
90
|
-
self.data_collector.
|
|
83
|
+
self.data_collector.forward_input_data_collect(
|
|
84
|
+
api_or_module_name,
|
|
85
|
+
module,
|
|
86
|
+
pid,
|
|
87
|
+
module_input_output,
|
|
88
|
+
is_recompute
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
self.inner_switch = False
|
|
91
92
|
return args, kwargs
|
|
92
93
|
|
|
94
|
+
def grad_hook(module, ori_name, param_name):
|
|
95
|
+
def hook_fn(grad):
|
|
96
|
+
if not self.should_execute_hook(module_type, module, False):
|
|
97
|
+
return grad
|
|
98
|
+
self.inner_switch = True
|
|
99
|
+
self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
|
|
100
|
+
self.inner_switch = False
|
|
101
|
+
return grad
|
|
102
|
+
|
|
103
|
+
return hook_fn
|
|
104
|
+
|
|
105
|
+
def register_param_hook(ori_name, module, params_dict):
|
|
106
|
+
'''
|
|
107
|
+
注册参数hook
|
|
108
|
+
'''
|
|
109
|
+
# data_mode为forward时,不注册参数hook
|
|
110
|
+
if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
|
|
111
|
+
for param_name, param in params_dict.items():
|
|
112
|
+
if param.requires_grad:
|
|
113
|
+
name = ori_name + Const.SEP + param_name
|
|
114
|
+
old_handle = self.hook_handle_dict.get(name)
|
|
115
|
+
if old_handle and hasattr(old_handle, "remove"):
|
|
116
|
+
old_handle.remove()
|
|
117
|
+
handle = param.register_hook(grad_hook(module, ori_name, param_name))
|
|
118
|
+
self.hook_handle_dict[name] = handle
|
|
119
|
+
|
|
120
|
+
def init_params_grad_info(module, params_dict):
|
|
121
|
+
'''
|
|
122
|
+
初始化参数梯度信息, 在前向hook结束后, 将参数梯度信息写入cache_data中用于占位
|
|
123
|
+
'''
|
|
124
|
+
if not params_dict:
|
|
125
|
+
return
|
|
126
|
+
if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
|
|
127
|
+
grad_name = module.params_grad_name if hasattr(module, 'params_grad_name') else None
|
|
128
|
+
# 判断是否已经在cache_data中进行了占位, 若没有则先写入cache_data中
|
|
129
|
+
if not self.params_grad_info.get(grad_name):
|
|
130
|
+
data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
|
|
131
|
+
# 当模块中的参数有requires_grad属性为True时,才会进行梯度计算,此时才需要占位
|
|
132
|
+
if data_info.get(grad_name):
|
|
133
|
+
# 将grad_name的data_info先写入cache_data中, 梯度计算后再更新
|
|
134
|
+
self.data_collector.handle_data(grad_name, data_info,
|
|
135
|
+
flush=self.data_collector.data_processor.is_terminated)
|
|
136
|
+
# 记录当前模块的参数梯度信息已占位
|
|
137
|
+
self.params_grad_info[grad_name] = True
|
|
138
|
+
|
|
93
139
|
def forward_hook(api_or_module_name, module, args, kwargs, output):
|
|
94
|
-
if not self.should_execute_hook():
|
|
140
|
+
if not self.should_execute_hook(module_type, module, True):
|
|
95
141
|
return None
|
|
142
|
+
is_recompute = is_recomputation()
|
|
96
143
|
|
|
97
|
-
|
|
98
|
-
api_or_module_name = module.mindstudio_reserved_name
|
|
99
|
-
self.data_collector.update_api_or_module_name(api_or_module_name)
|
|
100
|
-
|
|
144
|
+
self.inner_switch = True
|
|
101
145
|
if self.config.online_run_ut:
|
|
146
|
+
self.data_collector.update_api_or_module_name(api_or_module_name)
|
|
102
147
|
if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
|
|
103
148
|
return None
|
|
104
|
-
api_data = ApiData(
|
|
149
|
+
api_data = ApiData(
|
|
150
|
+
api_or_module_name[:-len(Const.FORWARD_NAME_SUFFIX)],
|
|
151
|
+
args,
|
|
152
|
+
kwargs,
|
|
153
|
+
output,
|
|
154
|
+
self.current_iter,
|
|
155
|
+
self.current_rank
|
|
156
|
+
)
|
|
105
157
|
self.attl_send(api_data)
|
|
158
|
+
self.inner_switch = False
|
|
106
159
|
return None
|
|
107
160
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
161
|
+
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
|
|
162
|
+
if module_type == BaseScope.Module_Type_Module:
|
|
163
|
+
api_or_module_name = module.mindstudio_reserved_name[-1]
|
|
164
|
+
self.data_collector.update_api_or_module_name(api_or_module_name)
|
|
165
|
+
params_dict = {}
|
|
166
|
+
if self.config.task != Const.STRUCTURE:
|
|
167
|
+
params_dict = {
|
|
168
|
+
key.split(Const.SEP)[-1]: value
|
|
169
|
+
for key, value in module.named_parameters(recurse=False)
|
|
170
|
+
}
|
|
171
|
+
setattr(module_input_output, Const.PARAMS, params_dict)
|
|
172
|
+
# 判断是否需要注册参数hook
|
|
173
|
+
if params_dict:
|
|
174
|
+
ori_name = api_or_module_name.rsplit(Const.SEP, 2)[0]
|
|
175
|
+
grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
|
|
176
|
+
# 首次执行前向hook时,添加params_grad_name属性,并注册参数hook
|
|
177
|
+
setattr(module, 'params_grad_name', grad_name)
|
|
178
|
+
register_param_hook(ori_name, module, params_dict)
|
|
179
|
+
self.data_collector.forward_data_collect(
|
|
180
|
+
api_or_module_name,
|
|
181
|
+
module,
|
|
182
|
+
pid,
|
|
183
|
+
module_input_output,
|
|
184
|
+
is_recompute
|
|
185
|
+
)
|
|
186
|
+
init_params_grad_info(module, params_dict)
|
|
187
|
+
else:
|
|
188
|
+
self.data_collector.update_api_or_module_name(api_or_module_name)
|
|
189
|
+
self.data_collector.forward_output_data_collect(
|
|
190
|
+
api_or_module_name,
|
|
191
|
+
module,
|
|
192
|
+
pid,
|
|
193
|
+
module_input_output,
|
|
194
|
+
is_recompute
|
|
195
|
+
)
|
|
196
|
+
|
|
197
|
+
if self.data_collector.if_return_forward_new_output():
|
|
198
|
+
forward_new_output = self.data_collector.get_forward_new_output()
|
|
199
|
+
self.inner_switch = False
|
|
200
|
+
return forward_new_output
|
|
201
|
+
self.inner_switch = False
|
|
113
202
|
return output
|
|
114
203
|
|
|
115
204
|
def forward_hook_torch_version_below_2(api_or_module_name, module, args, output):
|
|
116
205
|
return forward_hook(api_or_module_name, module, args, {}, output)
|
|
117
206
|
|
|
118
207
|
def backward_hook(api_or_module_name, module, grad_input, grad_output):
|
|
119
|
-
if not self.should_execute_hook():
|
|
208
|
+
if not self.should_execute_hook(module_type, module, False):
|
|
120
209
|
return
|
|
210
|
+
is_recompute = is_recomputation()
|
|
121
211
|
|
|
212
|
+
self.inner_switch = True
|
|
122
213
|
if module_type == BaseScope.Module_Type_Module:
|
|
123
|
-
api_or_module_name = module.mindstudio_reserved_name
|
|
214
|
+
api_or_module_name = module.mindstudio_reserved_name[-1]
|
|
124
215
|
self.data_collector.update_api_or_module_name(api_or_module_name)
|
|
125
216
|
|
|
126
217
|
if self.config.online_run_ut:
|
|
218
|
+
self.inner_switch = False
|
|
127
219
|
return
|
|
128
220
|
|
|
129
221
|
if self.data_collector:
|
|
130
222
|
# 此处获取到的grad_input实际为反向过程的输出数据,grad_output为反向过程的输入数据,因此传入时调换顺序
|
|
131
223
|
module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
|
|
132
|
-
self.data_collector.backward_data_collect(
|
|
224
|
+
self.data_collector.backward_data_collect(
|
|
225
|
+
api_or_module_name,
|
|
226
|
+
module,
|
|
227
|
+
pid,
|
|
228
|
+
module_input_output,
|
|
229
|
+
is_recompute
|
|
230
|
+
)
|
|
231
|
+
self.inner_switch = False
|
|
133
232
|
|
|
134
233
|
pid = os.getpid()
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
234
|
+
full_forward_name = None
|
|
235
|
+
full_backward_name = None
|
|
236
|
+
if module_type == BaseScope.Module_Type_API:
|
|
237
|
+
full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD
|
|
238
|
+
full_backward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.BACKWARD
|
|
239
|
+
pre_forward_hook_fn = functools.partial(pre_hook, full_forward_name)
|
|
240
|
+
forward_hook_fn = functools.partial(forward_hook, full_forward_name)
|
|
241
|
+
backward_hook_fn = functools.partial(backward_hook, full_backward_name)
|
|
242
|
+
forward_hook_torch_version_below_2_fn = functools.partial(
|
|
243
|
+
forward_hook_torch_version_below_2,
|
|
244
|
+
full_forward_name
|
|
245
|
+
)
|
|
142
246
|
return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)
|
|
143
247
|
|
|
144
|
-
def start(self, model
|
|
248
|
+
def start(self, model):
|
|
249
|
+
if self.config.level == Const.LEVEL_DEBUG:
|
|
250
|
+
return
|
|
145
251
|
if self.need_stop_service():
|
|
146
252
|
return
|
|
147
253
|
|
|
@@ -155,10 +261,10 @@ class Service:
|
|
|
155
261
|
|
|
156
262
|
if self.config.rank and self.current_rank not in self.config.rank:
|
|
157
263
|
return
|
|
158
|
-
self.
|
|
264
|
+
self.register_module_hook()
|
|
265
|
+
if self.config.level == Const.LEVEL_MIX:
|
|
266
|
+
register_optimizer_hook(self.data_collector)
|
|
159
267
|
self.first_start = False
|
|
160
|
-
if api_origin:
|
|
161
|
-
api_register.api_modularity()
|
|
162
268
|
if self.config.online_run_ut and torch_version_above_or_equal_2:
|
|
163
269
|
run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute)
|
|
164
270
|
self.switch = True
|
|
@@ -168,32 +274,39 @@ class Service:
|
|
|
168
274
|
logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")
|
|
169
275
|
|
|
170
276
|
def stop(self):
|
|
171
|
-
if self.
|
|
277
|
+
if self.config.level == Const.LEVEL_DEBUG:
|
|
172
278
|
return
|
|
173
|
-
if self.
|
|
279
|
+
if self.should_stop_service:
|
|
174
280
|
return
|
|
175
281
|
if self.config.step and self.current_iter not in self.config.step:
|
|
176
282
|
return
|
|
177
283
|
if self.config.rank and self.current_rank not in self.config.rank:
|
|
178
284
|
return
|
|
179
285
|
self.switch = False
|
|
286
|
+
if self.config.level == Const.LEVEL_L2:
|
|
287
|
+
return
|
|
180
288
|
if self.config.online_run_ut and torch_version_above_or_equal_2:
|
|
181
289
|
run_ut_dispatch(self.attl, False, self.config.online_run_ut_recompute)
|
|
182
290
|
return
|
|
291
|
+
if self.config.async_dump:
|
|
292
|
+
self.data_collector.fill_stack_tensor_data()
|
|
293
|
+
if self.config.task == Const.TENSOR:
|
|
294
|
+
self.data_collector.data_processor.dump_async_data()
|
|
183
295
|
self.data_collector.write_json()
|
|
184
296
|
|
|
185
297
|
def step(self):
|
|
298
|
+
if self.config.level == Const.LEVEL_DEBUG:
|
|
299
|
+
return
|
|
186
300
|
if self.should_stop_service:
|
|
187
301
|
return
|
|
302
|
+
if self.config.async_dump:
|
|
303
|
+
self.data_collector.fill_stack_tensor_data()
|
|
304
|
+
if self.config.task == Const.TENSOR:
|
|
305
|
+
self.data_collector.data_processor.dump_async_data()
|
|
306
|
+
self.data_collector.write_json()
|
|
188
307
|
self.current_iter += 1
|
|
189
308
|
self.data_collector.update_iter(self.current_iter)
|
|
190
|
-
|
|
191
|
-
ModuleProcesser.reset_module_stats()
|
|
192
|
-
HOOKModule.reset_module_stats()
|
|
193
|
-
self.data_collector.data_writer.reset_cache()
|
|
194
|
-
|
|
195
|
-
if self.config.level == Const.LEVEL_L2:
|
|
196
|
-
self.data_collector.data_processor.reset_status()
|
|
309
|
+
self.reset_status()
|
|
197
310
|
|
|
198
311
|
def need_stop_service(self):
|
|
199
312
|
if self.should_stop_service:
|
|
@@ -204,8 +317,6 @@ class Service:
|
|
|
204
317
|
if self.config.online_run_ut:
|
|
205
318
|
# send stop signal if online_run_ut
|
|
206
319
|
self.attl_stop()
|
|
207
|
-
if self.config.level in [Const.LEVEL_L1, Const.LEVEL_L2, Const.LEVEL_MIX]:
|
|
208
|
-
api_register.api_originality()
|
|
209
320
|
self.switch = False
|
|
210
321
|
self.should_stop_service = True
|
|
211
322
|
print_tools_ends_info()
|
|
@@ -214,10 +325,18 @@ class Service:
|
|
|
214
325
|
return True
|
|
215
326
|
return False
|
|
216
327
|
|
|
217
|
-
def should_execute_hook(self):
|
|
218
|
-
|
|
328
|
+
def should_execute_hook(self, hook_type, module, is_forward):
|
|
329
|
+
is_module_hook = hook_type == BaseScope.Module_Type_Module
|
|
330
|
+
if is_module_hook and not self.switch:
|
|
331
|
+
return False
|
|
332
|
+
elif not is_module_hook and is_forward and not self.switch:
|
|
333
|
+
return False
|
|
334
|
+
elif not is_module_hook and not is_forward and not module.forward_data_collected:
|
|
335
|
+
return False
|
|
336
|
+
|
|
337
|
+
if self.inner_switch:
|
|
219
338
|
return False
|
|
220
|
-
if self.data_collector
|
|
339
|
+
if not self.data_collector or self.data_collector.data_processor.is_terminated:
|
|
221
340
|
return False
|
|
222
341
|
return True
|
|
223
342
|
|
|
@@ -239,55 +358,28 @@ class Service:
|
|
|
239
358
|
else:
|
|
240
359
|
dump_data_dir = None
|
|
241
360
|
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \
|
|
259
|
-
module.__class__.__name__ + Const.SEP
|
|
260
|
-
|
|
261
|
-
pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.build_hook(
|
|
262
|
-
BaseScope.Module_Type_Module, prefix)
|
|
263
|
-
if torch_version_above_or_equal_2:
|
|
264
|
-
module.register_forward_hook(forward_hook, with_kwargs=True)
|
|
265
|
-
else:
|
|
266
|
-
self.check_register_full_backward_hook(module)
|
|
267
|
-
module.register_full_backward_hook(
|
|
268
|
-
self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
|
|
269
|
-
module.register_forward_hook(forward_hook_torch_version_below_2)
|
|
270
|
-
self.check_register_full_backward_hook(module)
|
|
271
|
-
module.register_full_backward_hook(backward_hook)
|
|
272
|
-
|
|
273
|
-
module.register_forward_pre_hook(
|
|
274
|
-
self.module_processor.node_hook(prefix + Const.FORWARD, Const.START))
|
|
275
|
-
module.register_forward_hook(
|
|
276
|
-
self.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
|
|
277
|
-
if torch_version_above_or_equal_2:
|
|
278
|
-
module.register_full_backward_pre_hook(
|
|
279
|
-
self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
|
|
280
|
-
self.check_register_full_backward_hook(module)
|
|
281
|
-
module.register_full_backward_hook(
|
|
282
|
-
self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
|
|
283
|
-
|
|
284
|
-
if self.config.level in ["mix", "L1", "L2"]:
|
|
285
|
-
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API),
|
|
286
|
-
self.config.online_run_ut)
|
|
361
|
+
dump_path_aggregation = DumpPathAggregation()
|
|
362
|
+
dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
|
|
363
|
+
dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
|
|
364
|
+
dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
|
|
365
|
+
dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
|
|
366
|
+
dump_path_aggregation.free_benchmark_file_path = os.path.join(dump_dir, "free_benchmark.csv")
|
|
367
|
+
self.data_collector.update_dump_paths(dump_path_aggregation)
|
|
368
|
+
self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)
|
|
369
|
+
|
|
370
|
+
def register_api_hook(self):
|
|
371
|
+
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
|
|
372
|
+
logger.info_on_rank_0(f"The api {self.config.task} hook function is successfully mounted to the model.")
|
|
373
|
+
api_register.initialize_hook(
|
|
374
|
+
functools.partial(self.build_hook, BaseScope.Module_Type_API),
|
|
375
|
+
self.config.online_run_ut
|
|
376
|
+
)
|
|
287
377
|
api_register.api_modularity()
|
|
288
378
|
|
|
289
|
-
|
|
290
|
-
|
|
379
|
+
def register_module_hook(self):
|
|
380
|
+
if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]:
|
|
381
|
+
logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.")
|
|
382
|
+
self.module_processor.register_module_hook(self.model, self.build_hook)
|
|
291
383
|
|
|
292
384
|
def attl_init(self):
|
|
293
385
|
if self.config.online_run_ut:
|
|
@@ -319,3 +411,60 @@ class Service:
|
|
|
319
411
|
elif self.attl.socket_manager is not None:
|
|
320
412
|
logger.info(f"pid: {os.getpid()} finished, start send STOP signal.")
|
|
321
413
|
self.attl.socket_manager.send_stop_signal()
|
|
414
|
+
|
|
415
|
+
def reset_status(self):
|
|
416
|
+
ModuleProcesser.reset_module_stats()
|
|
417
|
+
HOOKModule.reset_module_stats()
|
|
418
|
+
self.data_collector.reset_status()
|
|
419
|
+
self.params_grad_info.clear()
|
|
420
|
+
|
|
421
|
+
if self.config.level == Const.LEVEL_L2:
|
|
422
|
+
self.data_collector.data_processor.reset_status()
|
|
423
|
+
return
|
|
424
|
+
if self.config.step and self.current_iter not in self.config.step:
|
|
425
|
+
return
|
|
426
|
+
if self.config.rank and self.current_rank not in self.config.rank:
|
|
427
|
+
return
|
|
428
|
+
|
|
429
|
+
def init_for_debug_level(self):
|
|
430
|
+
if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]):
|
|
431
|
+
return
|
|
432
|
+
try:
|
|
433
|
+
self.current_rank = get_rank_if_initialized()
|
|
434
|
+
except DistributedNotInitializedError:
|
|
435
|
+
self.current_rank = None
|
|
436
|
+
|
|
437
|
+
# dir: dump_path -- rank{} -- debug.json
|
|
438
|
+
self.dump_iter_dir = self.config.dump_path
|
|
439
|
+
cur_rank = self.current_rank if self.current_rank is not None else ''
|
|
440
|
+
dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
|
|
441
|
+
create_directory(dump_dir)
|
|
442
|
+
if self.config.task in self.data_collector.tasks_need_tensor_data:
|
|
443
|
+
dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
|
|
444
|
+
create_directory(dump_data_dir)
|
|
445
|
+
else:
|
|
446
|
+
dump_data_dir = None
|
|
447
|
+
|
|
448
|
+
dump_path_aggregation = DumpPathAggregation()
|
|
449
|
+
dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
|
|
450
|
+
dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
|
|
451
|
+
self.data_collector.update_dump_paths(dump_path_aggregation)
|
|
452
|
+
self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)
|
|
453
|
+
|
|
454
|
+
self.debug_variable_counter = defaultdict(int)
|
|
455
|
+
|
|
456
|
+
def save(self, variable, name, save_backward):
|
|
457
|
+
if self.config.level != Const.LEVEL_DEBUG:
|
|
458
|
+
return
|
|
459
|
+
count = self.debug_variable_counter[name]
|
|
460
|
+
self.debug_variable_counter[name] += 1
|
|
461
|
+
|
|
462
|
+
name_with_count = f"{name}.{count}"
|
|
463
|
+
grad_name_with_count = f"{name}_grad.{count}"
|
|
464
|
+
|
|
465
|
+
# forward save
|
|
466
|
+
self.data_collector.debug_data_collect_forward(variable, name_with_count)
|
|
467
|
+
|
|
468
|
+
# backward save
|
|
469
|
+
if save_backward:
|
|
470
|
+
self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)
|