mindstudio-probe 8.1.1__py3-none-any.whl → 8.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/METADATA +1 -1
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/RECORD +95 -94
- msprobe/core/common/const.py +3 -0
- msprobe/core/common/file_utils.py +45 -5
- msprobe/core/common/utils.py +117 -13
- msprobe/core/common_config.py +15 -1
- msprobe/core/compare/acc_compare.py +21 -9
- msprobe/core/compare/compare_cli.py +10 -2
- msprobe/core/compare/merge_result/merge_result.py +1 -1
- msprobe/core/compare/utils.py +8 -2
- msprobe/core/config_check/checkers/base_checker.py +2 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +5 -4
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +4 -1
- msprobe/core/config_check/config_check_cli.py +1 -1
- msprobe/core/config_check/config_checker.py +1 -2
- msprobe/core/data_dump/data_collector.py +4 -1
- msprobe/core/data_dump/data_processor/mindspore_processor.py +23 -1
- msprobe/core/data_dump/data_processor/pytorch_processor.py +3 -25
- msprobe/core/debugger/precision_debugger.py +13 -8
- msprobe/core/hook_manager.py +112 -82
- msprobe/core/monitor/utils.py +338 -0
- msprobe/core/service.py +2 -1
- msprobe/core/single_save/single_comparator.py +5 -3
- msprobe/docs/01.installation.md +1 -0
- msprobe/docs/05.data_dump_PyTorch.md +4 -4
- msprobe/docs/07.accuracy_checker_PyTorch.md +14 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +13 -11
- msprobe/docs/10.accuracy_compare_PyTorch.md +3 -1
- msprobe/docs/11.accuracy_compare_MindSpore.md +4 -2
- msprobe/docs/12.overflow_check_PyTorch.md +3 -2
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/14.data_parse_PyTorch.md +35 -32
- msprobe/docs/21.visualization_PyTorch.md +9 -8
- msprobe/docs/22.visualization_MindSpore.md +1 -0
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/24.code_mapping_Mindspore.md +6 -5
- msprobe/docs/31.config_check.md +15 -5
- msprobe/docs/33.generate_operator_MindSpore.md +2 -2
- msprobe/docs/34.RL_collect.md +18 -9
- msprobe/docs/35.nan_analyze.md +4 -3
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +29 -1
- msprobe/mindspore/cell_processor.py +35 -14
- msprobe/mindspore/code_mapping/bind.py +23 -4
- msprobe/mindspore/code_mapping/graph_parser.py +6 -4
- msprobe/mindspore/common/utils.py +3 -0
- msprobe/mindspore/compare/common_dir_compare.py +32 -12
- msprobe/mindspore/compare/ms_graph_compare.py +7 -2
- msprobe/mindspore/compare/utils.py +9 -1
- msprobe/mindspore/debugger/debugger_config.py +13 -11
- msprobe/mindspore/debugger/precision_debugger.py +67 -45
- msprobe/mindspore/dump/dump_tool_factory.py +2 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +14 -9
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +12 -7
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +27 -13
- msprobe/mindspore/dump/jit_dump.py +6 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +13 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +6 -5
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -0
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/monitor/common_func.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +3 -3
- msprobe/mindspore/monitor/utils.py +0 -252
- msprobe/mindspore/ms_config.py +0 -1
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/nan_analyze/graph.py +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +15 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +1 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +1 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -4
- msprobe/pytorch/common/utils.py +0 -16
- msprobe/pytorch/compare/pt_compare.py +5 -0
- msprobe/pytorch/debugger/debugger_config.py +12 -5
- msprobe/pytorch/debugger/precision_debugger.py +8 -1
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +1 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +44 -13
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +2 -0
- msprobe/pytorch/hook_module/hook_module.py +9 -9
- msprobe/pytorch/hook_module/pt_hook_manager.py +7 -7
- msprobe/pytorch/monitor/csv2tb.py +3 -10
- msprobe/pytorch/monitor/features.py +5 -0
- msprobe/pytorch/monitor/module_hook.py +6 -7
- msprobe/pytorch/monitor/module_metric.py +0 -3
- msprobe/pytorch/monitor/optimizer_collect.py +1 -1
- msprobe/pytorch/monitor/utils.py +1 -317
- msprobe/pytorch/online_dispatch/dispatch.py +1 -1
- msprobe/pytorch/online_dispatch/dump_compare.py +7 -1
- msprobe/pytorch/parse_tool/lib/utils.py +2 -4
- msprobe/visualization/graph_service.py +1 -1
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/top_level.txt +0 -0
msprobe/core/debugger/precision_debugger.py
CHANGED

```diff
@@ -18,7 +18,7 @@ import os
 from msprobe.core.common.const import Const, FileCheckConst, MsgConst
 from msprobe.core.common.exceptions import MsprobeException
 from msprobe.core.common.file_utils import FileChecker, load_json
-from msprobe.core.common.utils import get_real_step_or_rank, check_init_step
+from msprobe.core.common.utils import get_real_step_or_rank, check_init_step, ThreadSafe
 from msprobe.core.common_config import CommonConfig
```
```diff
@@ -27,13 +27,14 @@ class BasePrecisionDebugger:
     tasks_not_need_debugger = [Const.GRAD_PROBE]
 
     def __new__(cls, *args, **kwargs):
-        if cls._instance …
+        if not cls._instance:
+            with ThreadSafe():
+                if not cls._instance:
+                    cls._instance = super(BasePrecisionDebugger, cls).__new__(cls)
+                    cls._instance.config = None
+                    cls._instance.initialized = False
+                    cls.service = None
+                    cls.first_start = False
         return cls._instance
 
     def __init__(
```
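The rewritten `__new__` is a double-checked locking singleton: the unsynchronized outer check keeps the common already-created path lock-free, and the second check under the lock ensures that two threads racing past the outer check cannot both construct an instance. A minimal, self-contained sketch of the pattern, using a plain `threading.Lock` as a stand-in for msprobe's `ThreadSafe` helper (whose implementation is not part of this diff):

```python
import threading

class Singleton:
    _instance = None
    _lock = threading.Lock()  # stand-in for msprobe's ThreadSafe

    def __new__(cls):
        if not cls._instance:              # fast path: no lock once created
            with cls._lock:
                if not cls._instance:      # re-check: another thread may have won the race
                    cls._instance = super().__new__(cls)
        return cls._instance

assert Singleton() is Singleton()
```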
```diff
@@ -83,11 +84,13 @@ class BasePrecisionDebugger:
         raise NotImplementedError("Subclass must implement _get_task_config")
 
     @classmethod
+    @ThreadSafe.synchronized
     def forward_backward_dump_end(cls):
         instance = cls._instance
         instance.stop()
 
     @classmethod
+    @ThreadSafe.synchronized
     def set_init_step(cls, step):
         instance = cls._instance
         if not instance:
@@ -97,6 +100,7 @@ class BasePrecisionDebugger:
         instance.service.loop = 0
 
     @classmethod
+    @ThreadSafe.synchronized
     def register_custom_api(cls, module, api, api_prefix=None):
         if not api_prefix:
             api_prefix = getattr(module, "__name__", "Custom")
@@ -112,6 +116,7 @@ class BasePrecisionDebugger:
         instance.service.register_custom_api(module, api, api_prefix)
 
     @classmethod
+    @ThreadSafe.synchronized
     def restore_custom_api(cls, module, api):
         if not hasattr(module, api):
             raise MsprobeException(
```
msprobe/core/hook_manager.py
CHANGED
```diff
@@ -13,12 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from abc import ABC, abstractmethod
 import os
+import threading
+from abc import ABC, abstractmethod
+from collections import defaultdict
 
+from msprobe.core.common.log import logger
 from msprobe.core.common.runtime import Runtime
-from msprobe.core.common.utils import Const
+from msprobe.core.common.utils import Const, ThreadSafe
 from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs)
 
 
@@ -31,7 +33,7 @@ class HookSet:
 
 
 class BaseHookManager(ABC):
-    inner_switch = False
+    inner_switch = defaultdict(bool)
     hook_handle_dict = {}
     params_grad_info = {}
 
@@ -86,7 +88,7 @@ class BaseHookManager(ABC):
             grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
             # On the first forward-hook execution, add the params_grad_name attribute and register parameter hooks
             setattr(module, 'params_grad_name', grad_name)
-
+        # When data_mode is forward, do not register parameter hooks
         if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
             for param_name, param in params_dict.items():
                 if param.requires_grad:
@@ -116,7 +118,7 @@ class BaseHookManager(ABC):
             # Record that a placeholder for the current module's parameter-gradient info exists
             BaseHookManager.params_grad_info[grad_name] = True
 
-    def _should_execute_hook(self, hook_type, module, is_forward):
+    def _should_execute_hook(self, hook_type, module, is_forward, tid):
         is_module_hook = hook_type == Const.MODULE
         if is_module_hook and not Runtime.is_running:
             return False
@@ -124,7 +126,7 @@ class BaseHookManager(ABC):
             return False
         elif not is_module_hook and not is_forward and not module.forward_data_collected:
             return False
-        if BaseHookManager.inner_switch:
+        if BaseHookManager.inner_switch[tid]:
             return False
         if not self.data_collector or self.data_collector.data_processor.is_terminated:
             return False
```
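`inner_switch` is the manager's reentrancy guard: hooks set it while collecting so that any framework calls they make do not recursively trigger collection. Changing it from one shared boolean into a `defaultdict(bool)` keyed by `threading.get_ident()` scopes the guard to a single thread, so a thread inside a hook no longer suppresses hook execution on other threads. A small runnable sketch of the idea (names are illustrative, not msprobe APIs):

```python
import threading
from collections import defaultdict

inner_switch = defaultdict(bool)  # thread id -> "this thread is inside a hook"

def collect(payload):
    print(f"collected {payload} on thread {threading.get_ident()}")

def guarded_hook(payload):
    tid = threading.get_ident()
    if inner_switch[tid]:        # nested call on the same thread: skip
        return
    inner_switch[tid] = True     # hooks on other threads are unaffected
    try:
        collect(payload)         # may itself trigger hooks; they see the guard
    finally:
        inner_switch[tid] = False

guarded_hook("x")
```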
```diff
@@ -132,111 +134,139 @@ class BaseHookManager(ABC):
 
     def _build_grad_hook(self, module, ori_name, param_name):
         def hook_fn(grad):
-            if not self._should_execute_hook(Const.MODULE, module, False):
+            tid = threading.get_ident()
+            if not self._should_execute_hook(Const.MODULE, module, False, tid):
                 return
-            BaseHookManager.inner_switch = True
-            self.data_collector.params_data_collect(ori_name, param_name, self._pid, grad)
-            BaseHookManager.inner_switch = False
+            with ThreadSafe():
+                BaseHookManager.inner_switch[tid] = True
+                self.data_collector.params_data_collect(ori_name, param_name, self._pid, grad)
+                BaseHookManager.inner_switch[tid] = False
             return
+
         return hook_fn
 
     def _build_forward_pre_hook(self, hook_type, full_name, api_name):
         def forward_pre_hook(module, args, kwargs=None):
+            """
+            To keep L1-level data collection correct in multi-threaded scenarios, the init method of each
+            wrapped API and its forward_pre_hook must complete consecutively within a single thread.
+            The API's init method therefore takes the lock via ThreadSafe.acquire(), and the API's
+            forward_pre_hook releases it via ThreadSafe.release().
+            """
             if hook_type == Const.MODULE:
                 return
-            if not self._should_execute_hook(hook_type, module, True):
+
+            tid = threading.get_ident()
+            if not self._should_execute_hook(hook_type, module, True, tid):
+                ThreadSafe.release()
                 return
+
+            module.forward_data_collected = True
+            self._add_count(api_name)
+            if getattr(self.config, "online_run_ut", False):
+                ThreadSafe.release()
+                return
+
+            BaseHookManager.inner_switch[tid] = True
             if kwargs is None:
                 kwargs = module.msprobe_input_kwargs if hasattr(module, 'msprobe_input_kwargs') else {}
-            …
-            self._add_count(api_name)
-            module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
-            self.data_collector.update_api_or_module_name(full_name)
-            if getattr(self.config, "online_run_ut", False):
-                BaseHookManager.inner_switch = False
-                return
-            self.data_collector.forward_input_data_collect(
-                full_name,
-                module,
-                self._pid,
-                module_input_output,
-                self._is_recompute
-            )
-            BaseHookManager.inner_switch = False
-        return forward_pre_hook
-
-    def _build_forward_hook(self, hook_type, full_name):
-        def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None):
-            if not self._should_execute_hook(hook_type, module, True):
-                self._clear_input_kwargs(module)
-                return None
-            kwargs, output = self._process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs)
-            BaseHookManager.inner_switch = True
-            self.data_collector.update_api_or_module_name(full_name)
-            module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
-            with self._no_grad_context():
-                if getattr(self.config, "online_run_ut", False):
-                    if self.data_collector.scope and not self.data_collector.scope.check(full_name):
-                        return None
-                    if self.attl_manager:
-                        self.attl_manager.attl_send(full_name, args, kwargs, output)
-                    BaseHookManager.inner_switch = False
-                    return None
-                if hook_type == Const.MODULE:
-                    params_dict = self._get_params_dict(module)
-                    setattr(module_input_output, Const.PARAMS, params_dict)
-                    if params_dict:
-                        self._register_param_hook(full_name, module, params_dict)
+            try:
+                with self._no_grad_context():
+                    module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
                     self.data_collector.update_api_or_module_name(full_name)
-                    self.data_collector.forward_data_collect(
-                        full_name,
-                        module,
-                        self._pid,
-                        module_input_output,
-                        self._is_recompute
-                    )
-                    self._init_params_grad_info(module, params_dict)
-                else:
-                    self.data_collector.forward_output_data_collect(
+                    self.data_collector.forward_input_data_collect(
                         full_name,
                         module,
                         self._pid,
                         module_input_output,
                         self._is_recompute
                     )
+            except Exception as e:
+                logger.error(f"The forward pre hook execution of the {full_name} API failed.")
+                raise e
+            finally:
+                BaseHookManager.inner_switch[tid] = False
+                ThreadSafe.release()
+
+        return forward_pre_hook
+
+    def _build_forward_hook(self, hook_type, full_name):
+        def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None):
+            tid = threading.get_ident()
+            if not self._should_execute_hook(hook_type, module, True, tid):
                 self._clear_input_kwargs(module)
+                return None
 
-            if self.data_collector.if_return_forward_new_output():
-                forward_new_output = self.data_collector.get_forward_new_output()
-                BaseHookManager.inner_switch = False
-                return forward_new_output
+            with ThreadSafe():
+                kwargs, output = self._process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs)
+                BaseHookManager.inner_switch[tid] = True
+                self.data_collector.update_api_or_module_name(full_name)
+                module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
+                with self._no_grad_context():
+                    if getattr(self.config, "online_run_ut", False):
+                        if self.data_collector.scope and not self.data_collector.scope.check(full_name):
+                            return None
+                        if self.attl_manager:
+                            self.attl_manager.attl_send(full_name, args, kwargs, output)
+                        BaseHookManager.inner_switch[tid] = False
+                        return None
+                    if hook_type == Const.MODULE:
+                        params_dict = self._get_params_dict(module)
+                        setattr(module_input_output, Const.PARAMS, params_dict)
+                        if params_dict:
+                            self._register_param_hook(full_name, module, params_dict)
+                        self.data_collector.update_api_or_module_name(full_name)
+                        self.data_collector.forward_data_collect(
+                            full_name,
+                            module,
+                            self._pid,
+                            module_input_output,
+                            self._is_recompute
+                        )
+                        self._init_params_grad_info(module, params_dict)
+                    else:
+                        self.data_collector.forward_output_data_collect(
+                            full_name,
+                            module,
+                            self._pid,
+                            module_input_output,
+                            self._is_recompute
+                        )
+                    self._clear_input_kwargs(module)
+
+                if self.data_collector.if_return_forward_new_output():
+                    forward_new_output = self.data_collector.get_forward_new_output()
+                    BaseHookManager.inner_switch[tid] = False
+                    return forward_new_output
 
-            BaseHookManager.inner_switch = False
-            return output
+                BaseHookManager.inner_switch[tid] = False
+                return output
 
         return forward_hook
 
     def _build_backward_hook(self, hook_type, full_name):
         def backward_hook(module, grad_input, grad_output):
-            if not self._should_execute_hook(hook_type, module, False):
-                return
-            BaseHookManager.inner_switch = True
-            self.data_collector.update_api_or_module_name(full_name)
-            if getattr(self.config, "online_run_ut", False):
-                BaseHookManager.inner_switch = False
+            tid = threading.get_ident()
+            if not self._should_execute_hook(hook_type, module, False, tid):
                 return
-            need_exchange = self._need_exchange(module) if hook_type == Const.MODULE else True
-            if need_exchange:
-                module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
-            else:
-                module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
-            self.data_collector.backward_data_collect(
+
+            with ThreadSafe():
+                BaseHookManager.inner_switch[tid] = True
+                self.data_collector.update_api_or_module_name(full_name)
+                if getattr(self.config, "online_run_ut", False):
+                    BaseHookManager.inner_switch[tid] = False
+                    return
+                need_exchange = self._need_exchange(module) if hook_type == Const.MODULE else True
+                if need_exchange:
+                    module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
+                else:
+                    module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
+                self.data_collector.backward_data_collect(
                     full_name,
                     module,
                     self._pid,
                     module_input_output,
                     self._is_recompute
                 )
-            BaseHookManager.inner_switch = False
+                BaseHookManager.inner_switch[tid] = False
+
         return backward_hook
```
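Per the docstring above, the lock around L1 collection is deliberately split across call sites: the wrapped API acquires it when the call is initiated, and `forward_pre_hook` releases it on every exit path (the early returns and the `finally` block). A hypothetical sketch of that pairing, with `api_init` and the hook body as illustrative stand-ins rather than msprobe functions:

```python
import threading

_lock = threading.RLock()  # stand-in for ThreadSafe's process-wide lock

def api_init(name):
    # The wrapped API takes the lock when the call is initiated ...
    _lock.acquire()
    print(f"{name}: init, lock held")

def forward_pre_hook(name):
    # ... and the pre-hook releases it on every exit path, so the
    # init + pre-hook pair runs back-to-back on one thread.
    try:
        print(f"{name}: collecting inputs")
    finally:
        _lock.release()
        print(f"{name}: lock released")

api_init("Functional.add.0")
forward_pre_hook("Functional.add.0")
```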
msprobe/core/monitor/utils.py
ADDED

```diff
@@ -0,0 +1,338 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import namedtuple
+from datetime import timezone, timedelta
+from functools import wraps
+from datetime import datetime
+import os
+import re
+
+from msprobe.core.common.const import MonitorConst
+from msprobe.core.common.log import logger
+from msprobe.core.common.utils import is_int
+from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod
+
+
+beijing_tz = timezone(timedelta(hours=8))
+MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio"))
+
+
+class MsgConst:
+    """
+    Class for log messages const
+    """
+    SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"]
+
+
+def get_output_base_dir():
+    return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)
+
+
+def filter_special_chars(func):
+    @wraps(func)
+    def func_level(msg):
+        for char in MsgConst.SPECIAL_CHAR:
+            msg = msg.replace(char, '_')
+        return func(msg)
+
+    return func_level
+
+
+def validate_ops(ops):
+    if not isinstance(ops, list):
+        raise TypeError("ops should be a list")
+    valid_ops = []
+    for op in ops:
+        if op not in MonitorConst.OP_LIST:
+            logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}")
+            continue
+        valid_ops.append(op)
+    if not valid_ops:
+        default_op = MonitorConst.OP_LIST[0]
+        valid_ops.append(default_op)
+        logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used")
+    # Add the default shape and dtype ops
+    if "shape" not in valid_ops:
+        valid_ops.append("shape")
+    if "dtype" not in valid_ops:
+        valid_ops.append("dtype")
+    return valid_ops
+
+
+def validate_ndigits(ndigits):
+    if not ndigits:
+        return
+    if not is_int(ndigits) or ndigits <= 0:
+        raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.")
+    if ndigits > MonitorConst.MAX_NDIGITS:
+        raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.")
+
+
+def validate_ranks(ranks):
+    if not isinstance(ranks, list):
+        raise TypeError("module_ranks should be a list")
+    for rank in ranks:
+        if not isinstance(rank, int) or isinstance(rank, bool):
+            raise TypeError(f"element in module_ranks should be a int, get {type(rank)}")
+
+
+def validate_targets(targets):
+    if not isinstance(targets, dict):
+        raise TypeError('targets in config.json should be a dict')
+    for module_name, field in targets.items():
+        if not isinstance(module_name, str):
+            raise TypeError('key of targets should be module_name[str] in config.json')
+        if not isinstance(field, dict):
+            raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json')
+
+
+def validate_print_struct(print_struct):
+    if not isinstance(print_struct, bool):
+        raise TypeError("print_struct should be a bool")
+
+
+def validate_ur_distribution(ur_distribution):
+    if not isinstance(ur_distribution, bool):
+        raise TypeError('ur_distribution should be a bool')
+
+
+def validate_xy_distribution(xy_distribution):
+    if not isinstance(xy_distribution, bool):
+        raise TypeError('xy_distribution should be a bool')
+
+
+def validate_wg_distribution(wg_distribution):
+    if not isinstance(wg_distribution, bool):
+        raise TypeError('wg_distribution should be a bool')
+
+
+def validate_mg_distribution(mg_distribution):
+    if not isinstance(mg_distribution, bool):
+        raise TypeError('mg_distribution should be a bool')
+
+
+def validate_param_distribution(param_distribution):
+    if not isinstance(param_distribution, bool):
+        raise TypeError('param_distribution should be a bool')
+
+
+def validate_cc_distribution(cc_distribution):
+    if not isinstance(cc_distribution, dict):
+        raise TypeError('cc_distribution should be a dictionary')
+    for key, value in cc_distribution.items():
+        if key == 'enable':
+            if not isinstance(value, bool):
+                raise TypeError('cc_distribution enable should be a bool')
+        elif key == 'cc_codeline':
+            if not isinstance(value, list):
+                raise TypeError('cc_distribution cc_codeline should be a list')
+        elif key == 'cc_pre_hook':
+            if not isinstance(value, bool):
+                raise TypeError('cc_distribution cc_pre_hook should be a bool')
+        elif key == 'cc_log_only':
+            if not isinstance(value, bool):
+                raise TypeError('cc_distribution cc_log_only should be a bool')
+        else:
+            raise TypeError(f'{key} of cc_distribution is not supported.')
+
+
+def validate_squash_name(squash_name):
+    if not isinstance(squash_name, bool):
+        raise TypeError('squash_name should be a bool')
+
+
+def validate_alert(alert):
+    if not isinstance(alert, dict):
+        raise TypeError('alert should be a dictionary')
+    rules = alert.get('rules')
+    if rules and isinstance(rules, list):
+        for rule in rules:
+            rule_name = rule.get("rule_name")
+            if rule_name and rule_name not in MonitorConst.RULE_NAME:
+                raise TypeError(f"{rule_name} is not supported")
+            args = rule.get("args")
+            if args and isinstance(args, dict):
+                threshold = args.get("threshold")
+                if not isinstance(threshold, (float, int)) or threshold < 0:
+                    raise TypeError('threshold must be float and not less than 0')
+    dump = alert.get('dump')
+    if dump and not isinstance(dump, bool):
+        raise TypeError('dump must be bool.')
+
+
+def validate_step_count_per_record(step_count_per_record):
+    if not is_int(step_count_per_record):
+        raise TypeError('step_count_per_record must be int.')
+    if step_count_per_record < 1:
+        raise ValueError("step_count_per_record must greater than 0")
+    if step_count_per_record > 1e6:
+        raise ValueError("step_count_per_record must smaller than 1e6")
+
+
+def validate_dynamic_on(dynamic_on):
+    if not isinstance(dynamic_on, bool):
+        raise TypeError('dynamic_on should be a bool')
+
+
+def validate_monitor_mbs_grad(monitor_mbs_grad):
+    if not isinstance(monitor_mbs_grad, bool):
+        logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.')
+        return False
+    return monitor_mbs_grad
+
+
+def validate_append_output(append_output):
+    if not isinstance(append_output, list):
+        raise TypeError('append_output should be a list')
+    if len(append_output) > 0 and len(append_output) != 2:
+        raise ValueError('append_output should be empty or contain exactly 2 elements')
+
+
+def validate_config(config):
+    config['ops'] = validate_ops(config.get('ops', []))
+
+    ndigits = config.get('ndigits')
+    validate_ndigits(ndigits)
+
+    eps = config.get('eps', 1e-8)
+    if not isinstance(eps, float):
+        raise TypeError("eps should be a float")
+
+    ranks = config.get("module_ranks", [])
+    validate_ranks(ranks)
+
+    targets = config.get("targets", {})
+    validate_targets(targets)
+
+    print_struct = config.get('print_struct', False)
+    validate_print_struct(print_struct)
+
+    ur_distribution = config.get('ur_distribution', False)
+    validate_ur_distribution(ur_distribution)
+
+    xy_distribution = config.get('xy_distribution', False)
+    validate_xy_distribution(xy_distribution)
+
+    wg_distribution = config.get('wg_distribution', False)
+    validate_wg_distribution(wg_distribution)
+
+    mg_distribution = config.get('mg_distribution', False)
+    validate_mg_distribution(mg_distribution)
+
+    param_distribution = config.get('param_distribution', False)
+    validate_param_distribution(param_distribution)
+
+    cc_distribution = config.get('cc_distribution', {})
+    validate_cc_distribution(cc_distribution)
+
+    alert = config.get('alert', {})
+    validate_alert(alert)
+
+    step_count_per_record = config.get('step_count_per_record', 1)
+    validate_step_count_per_record(step_count_per_record)
+
+    config["start_step"] = validate_int_arg(config.get("start_step"), "start_step",
+                                            MonitorConst.DEFAULT_START_STEP, MonitorConst.DEFAULT_START_STEP)
+    config["collect_times"] = validate_int_arg(config.get("collect_times"), "collect_times",
+                                               MonitorConst.DEFAULT_MIN_COLLECT_TIMES,
+                                               MonitorConst.DEFAULT_MAX_COLLECT_TIMES)
+    config["step_interval"] = validate_int_arg(config.get("step_interval"), "step_interval",
+                                               MonitorConst.DEFAULT_STEP_INTERVAL, MonitorConst.DEFAULT_STEP_INTERVAL)
+
+    squash_name = config.get('squash_name', True)
+    validate_squash_name(squash_name)
+
+    time_tags = config.get("append_output", [])
+    validate_append_output(time_tags)
+
+    config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False))
+
+    dynamic_on = config.get('dynamic_on', False)
+    validate_dynamic_on(dynamic_on)
+
+    if not targets:
+        if xy_distribution:
+            config["all_xy"] = True
+        config["targets"] = {"": {}}
+
+
+def time_str2time_digit(time_str):
+    time_format = '%b%d_%H-%M-%S'
+    if not isinstance(time_str, str):
+        raise TypeError(f"time_str:{time_str} should be a str")
+    try:
+        time_digit = datetime.strptime(time_str, time_format)
+    except Exception as e:
+        raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \
+of existing output dirpath, like 'Dec03_21-34-40'.") from e
+    return time_digit
+
+
+def get_target_output_dir(monitor_path, time_start, time_end):
+    check_file_or_directory_path(monitor_path, isdir=True)
+    time_start = time_str2time_digit(time_start) if time_start is not None else time_start
+    time_end = time_str2time_digit(time_end) if time_end is not None else time_end
+    if time_start and time_end and time_start > time_end:
+        raise ValueError(f"time_start({time_start}) greater than time_end({time_end})")
+    result = {}
+    for dirname in os.listdir(monitor_path):
+        match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname)
+        if not match:
+            continue
+        time_tag = match.group(1)
+        rank = match.group(2)
+        target_time = time_str2time_digit(time_tag)
+        start_ok = time_start is None or target_time >= time_start
+        end_ok = time_end is None or target_time <= time_end
+        if start_ok and end_ok:
+            result[rank] = os.path.join(monitor_path, dirname)
+    return result
+
+
+def chmod_tensorboard_dir(path):
+    """
+    When format is configured as tensorboard, file permissions need to be set additionally.
+    """
+    try:
+        recursive_chmod(path)
+    except Exception as e:
+        logger.warning(f"chmod tensorboard dir wrong because {e}, not updated, please check!!!")
+
+
+def validate_set_monitor(grad_acc_steps, start_iteration):
+    """
+    validate parameters of set_monitor.
+    """
+    grad_acc_steps = validate_int_arg(grad_acc_steps, "grad_acc_steps",
+                                      MonitorConst.DEFAULT_GRAD_ACC_STEPS, MonitorConst.DEFAULT_GRAD_ACC_STEPS)
+
+    start_iteration = validate_int_arg(start_iteration, "start_iteration",
+                                       MonitorConst.DEFAULT_START_ITERATION, MonitorConst.DEFAULT_START_ITERATION)
+    return grad_acc_steps, start_iteration
+
+
+def validate_int_arg(value, name, minimum, default_value):
+    """Validate int args, if any exception occurs, use the default value."""
+    if value is None:
+        return default_value
+    try:
+        if not is_int(value):
+            raise TypeError(f"{name} must be int")
+        if value < minimum:
+            raise ValueError(f"{name} must greater than {minimum}")
+    except Exception as e:
+        value = default_value
+        logger.warning(f"Validate {name} failed, {e}, replaced with default value {value}.")
+    return value
```
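A brief usage sketch for two of the new helpers, assuming msprobe 8.1.2 is installed; the argument values are illustrative:

```python
from msprobe.core.monitor.utils import validate_int_arg, time_str2time_digit

# Out-of-range or non-int values fall back to the given default and log a warning.
steps = validate_int_arg(-5, "grad_acc_steps", minimum=1, default_value=1)
assert steps == 1

# Monitor output directories are matched by timestamps such as 'Dec03_21-34-40'
# (format '%b%d_%H-%M-%S'); anything else raises RuntimeError.
start = time_str2time_digit("Dec03_21-34-40")
print(start.month, start.day)  # 12 3
```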
msprobe/core/service.py
CHANGED
```diff
@@ -303,7 +303,8 @@ class BaseService(ABC):
             self.logger.info(
                 f"Current token id: {self.cur_token_id}, exceed token_range, early stop dump infer token.")
         self.cur_token_id += 1
-        …
+        # Here root_model is guaranteed to be of type Module/Cell or [Module/Cell]
+        if root_model and isinstance(root_model, list):
             root_model = root_model[0]
             self.logger.warning("Infer model can only input one to support token_range, choose the first one.")
         if self._is_online_run_ut:
```