mindstudio-probe 8.1.1__py3-none-any.whl → 8.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/METADATA +1 -1
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/RECORD +95 -94
- msprobe/core/common/const.py +3 -0
- msprobe/core/common/file_utils.py +45 -5
- msprobe/core/common/utils.py +117 -13
- msprobe/core/common_config.py +15 -1
- msprobe/core/compare/acc_compare.py +21 -9
- msprobe/core/compare/compare_cli.py +10 -2
- msprobe/core/compare/merge_result/merge_result.py +1 -1
- msprobe/core/compare/utils.py +8 -2
- msprobe/core/config_check/checkers/base_checker.py +2 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +5 -4
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +4 -1
- msprobe/core/config_check/config_check_cli.py +1 -1
- msprobe/core/config_check/config_checker.py +1 -2
- msprobe/core/data_dump/data_collector.py +4 -1
- msprobe/core/data_dump/data_processor/mindspore_processor.py +23 -1
- msprobe/core/data_dump/data_processor/pytorch_processor.py +3 -25
- msprobe/core/debugger/precision_debugger.py +13 -8
- msprobe/core/hook_manager.py +112 -82
- msprobe/core/monitor/utils.py +338 -0
- msprobe/core/service.py +2 -1
- msprobe/core/single_save/single_comparator.py +5 -3
- msprobe/docs/01.installation.md +1 -0
- msprobe/docs/05.data_dump_PyTorch.md +4 -4
- msprobe/docs/07.accuracy_checker_PyTorch.md +14 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +13 -11
- msprobe/docs/10.accuracy_compare_PyTorch.md +3 -1
- msprobe/docs/11.accuracy_compare_MindSpore.md +4 -2
- msprobe/docs/12.overflow_check_PyTorch.md +3 -2
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/14.data_parse_PyTorch.md +35 -32
- msprobe/docs/21.visualization_PyTorch.md +9 -8
- msprobe/docs/22.visualization_MindSpore.md +1 -0
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/24.code_mapping_Mindspore.md +6 -5
- msprobe/docs/31.config_check.md +15 -5
- msprobe/docs/33.generate_operator_MindSpore.md +2 -2
- msprobe/docs/34.RL_collect.md +18 -9
- msprobe/docs/35.nan_analyze.md +4 -3
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +29 -1
- msprobe/mindspore/cell_processor.py +35 -14
- msprobe/mindspore/code_mapping/bind.py +23 -4
- msprobe/mindspore/code_mapping/graph_parser.py +6 -4
- msprobe/mindspore/common/utils.py +3 -0
- msprobe/mindspore/compare/common_dir_compare.py +32 -12
- msprobe/mindspore/compare/ms_graph_compare.py +7 -2
- msprobe/mindspore/compare/utils.py +9 -1
- msprobe/mindspore/debugger/debugger_config.py +13 -11
- msprobe/mindspore/debugger/precision_debugger.py +67 -45
- msprobe/mindspore/dump/dump_tool_factory.py +2 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +14 -9
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +12 -7
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +27 -13
- msprobe/mindspore/dump/jit_dump.py +6 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +13 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +6 -5
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -0
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/monitor/common_func.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +3 -3
- msprobe/mindspore/monitor/utils.py +0 -252
- msprobe/mindspore/ms_config.py +0 -1
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/nan_analyze/graph.py +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +15 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +1 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +1 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -4
- msprobe/pytorch/common/utils.py +0 -16
- msprobe/pytorch/compare/pt_compare.py +5 -0
- msprobe/pytorch/debugger/debugger_config.py +12 -5
- msprobe/pytorch/debugger/precision_debugger.py +8 -1
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +1 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +44 -13
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +2 -0
- msprobe/pytorch/hook_module/hook_module.py +9 -9
- msprobe/pytorch/hook_module/pt_hook_manager.py +7 -7
- msprobe/pytorch/monitor/csv2tb.py +3 -10
- msprobe/pytorch/monitor/features.py +5 -0
- msprobe/pytorch/monitor/module_hook.py +6 -7
- msprobe/pytorch/monitor/module_metric.py +0 -3
- msprobe/pytorch/monitor/optimizer_collect.py +1 -1
- msprobe/pytorch/monitor/utils.py +1 -317
- msprobe/pytorch/online_dispatch/dispatch.py +1 -1
- msprobe/pytorch/online_dispatch/dump_compare.py +7 -1
- msprobe/pytorch/parse_tool/lib/utils.py +2 -4
- msprobe/visualization/graph_service.py +1 -1
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/top_level.txt +0 -0
|
@@ -20,7 +20,7 @@ import mindspore as ms
|
|
|
20
20
|
from mindspore._c_expression import MSContext
|
|
21
21
|
|
|
22
22
|
from msprobe.core.common.const import Const, MsgConst
|
|
23
|
-
from msprobe.core.common.utils import check_token_range
|
|
23
|
+
from msprobe.core.common.utils import check_token_range, ThreadSafe
|
|
24
24
|
from msprobe.core.common.runtime import Runtime
|
|
25
25
|
from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger
|
|
26
26
|
from msprobe.mindspore.cell_processor import CellProcessor
|
|
@@ -57,18 +57,14 @@ ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task", "dump_
|
|
|
57
57
|
|
|
58
58
|
class PrecisionDebugger(BasePrecisionDebugger):
|
|
59
59
|
|
|
60
|
-
def
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
return cls._instance
|
|
69
|
-
|
|
70
|
-
def __init__(self, config_path=None, task=None, dump_path=None,
|
|
71
|
-
level=None, step=None):
|
|
60
|
+
def __init__(
|
|
61
|
+
self,
|
|
62
|
+
config_path=None,
|
|
63
|
+
task=None,
|
|
64
|
+
dump_path=None,
|
|
65
|
+
level=None,
|
|
66
|
+
step=None
|
|
67
|
+
):
|
|
72
68
|
if self.initialized:
|
|
73
69
|
return
|
|
74
70
|
set_register_backward_hook_functions()
|
|
@@ -81,13 +77,15 @@ class PrecisionDebugger(BasePrecisionDebugger):
|
|
|
81
77
|
self.common_config.dump_path = dump_path if dump_path else self.common_config.dump_path
|
|
82
78
|
self.config = DebuggerConfig(self.common_config, self.task_config)
|
|
83
79
|
|
|
84
|
-
if self.
|
|
80
|
+
if self._is_kernel_dump() and not self.task_config.is_regex_valid:
|
|
81
|
+
raise ValueError('Illegal regular expressions exist in the list.')
|
|
82
|
+
|
|
83
|
+
if self._is_kernel_dump() and _msprobe_c:
|
|
85
84
|
os.environ["MS_HOOK_ENABLE"] = "on"
|
|
86
85
|
_msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path)
|
|
87
86
|
|
|
88
87
|
self.config.execution_mode = self._get_execution_mode()
|
|
89
88
|
if self._need_service():
|
|
90
|
-
self.config.check_config_with_l2()
|
|
91
89
|
self.service = MindsporeService(self.config)
|
|
92
90
|
|
|
93
91
|
Runtime.step_count = 0
|
|
@@ -119,8 +117,6 @@ class PrecisionDebugger(BasePrecisionDebugger):
|
|
|
119
117
|
|
|
120
118
|
@staticmethod
|
|
121
119
|
def _is_graph_dump(config: DebuggerConfig):
|
|
122
|
-
if config.level != MsConst.KERNEL:
|
|
123
|
-
return False
|
|
124
120
|
if not config.list:
|
|
125
121
|
return True
|
|
126
122
|
is_graph = any(item.startswith("name-regex") for item in config.list)
|
|
@@ -132,28 +128,23 @@ class PrecisionDebugger(BasePrecisionDebugger):
|
|
|
132
128
|
instance = cls._get_instance()
|
|
133
129
|
if instance is None:
|
|
134
130
|
return
|
|
135
|
-
if cls.
|
|
136
|
-
|
|
137
|
-
check_token_range(token_range)
|
|
138
|
-
instance.config.execution_mode = cls._get_execution_mode()
|
|
139
|
-
if cls._need_service():
|
|
140
|
-
if not instance.service:
|
|
141
|
-
instance.service = MindsporeService(instance.config)
|
|
142
|
-
instance.config.check_model(model, token_range)
|
|
143
|
-
instance.service.start(model, token_range)
|
|
131
|
+
if cls._is_kernel_dump():
|
|
132
|
+
cls._start_kernel_dump()
|
|
144
133
|
else:
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
if
|
|
155
|
-
|
|
156
|
-
|
|
134
|
+
check_token_range(token_range)
|
|
135
|
+
instance.config.execution_mode = cls._get_execution_mode()
|
|
136
|
+
if cls._need_service():
|
|
137
|
+
with ThreadSafe():
|
|
138
|
+
if not instance.service:
|
|
139
|
+
instance.service = MindsporeService(instance.config)
|
|
140
|
+
instance.config.check_model(model, token_range)
|
|
141
|
+
instance.service.start(model, token_range)
|
|
142
|
+
else:
|
|
143
|
+
if not instance.first_start:
|
|
144
|
+
get_api_register().restore_all_api()
|
|
145
|
+
handler = TaskHandlerFactory.create(instance.config, model)
|
|
146
|
+
handler.handle()
|
|
147
|
+
Runtime.is_running = True
|
|
157
148
|
instance.first_start = True
|
|
158
149
|
|
|
159
150
|
@classmethod
|
|
@@ -165,14 +156,15 @@ class PrecisionDebugger(BasePrecisionDebugger):
|
|
|
165
156
|
if instance.task == Const.GRAD_PROBE:
|
|
166
157
|
instance.gm.stop()
|
|
167
158
|
if instance.service:
|
|
168
|
-
|
|
159
|
+
with ThreadSafe():
|
|
160
|
+
instance.service.stop()
|
|
169
161
|
else:
|
|
170
162
|
Runtime.is_running = False
|
|
171
163
|
if enable_dynamic_kbyk_dump:
|
|
172
164
|
_dump_stop()
|
|
173
|
-
if cls.
|
|
165
|
+
if cls._is_kernel_dump() and _msprobe_c:
|
|
174
166
|
_msprobe_c._PrecisionDebugger().stop()
|
|
175
|
-
|
|
167
|
+
|
|
176
168
|
@classmethod
|
|
177
169
|
def step(cls):
|
|
178
170
|
instance = cls._get_instance()
|
|
@@ -180,12 +172,13 @@ class PrecisionDebugger(BasePrecisionDebugger):
|
|
|
180
172
|
return
|
|
181
173
|
|
|
182
174
|
if instance.service:
|
|
183
|
-
|
|
175
|
+
with ThreadSafe():
|
|
176
|
+
instance.service.step()
|
|
184
177
|
if is_graph_mode_cell_dump_allowed(instance.config):
|
|
185
178
|
GraphModeCellDump.step()
|
|
186
179
|
if enable_dynamic_kbyk_dump:
|
|
187
180
|
_dump_step(1)
|
|
188
|
-
if cls.
|
|
181
|
+
if cls._is_kernel_dump() and _msprobe_c:
|
|
189
182
|
_msprobe_c._PrecisionDebugger().step()
|
|
190
183
|
|
|
191
184
|
HOOKCell.cell_count = defaultdict(int)
|
|
@@ -193,6 +186,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
|
|
|
193
186
|
Runtime.step_count += 1
|
|
194
187
|
|
|
195
188
|
@classmethod
|
|
189
|
+
@ThreadSafe.synchronized
|
|
196
190
|
def monitor(cls, opt):
|
|
197
191
|
instance = cls._instance
|
|
198
192
|
if not instance:
|
|
@@ -202,6 +196,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
|
|
|
202
196
|
instance.gm.monitor(opt)
|
|
203
197
|
|
|
204
198
|
@classmethod
|
|
199
|
+
@ThreadSafe.synchronized
|
|
205
200
|
def save(cls, variable, name, save_backward=True):
|
|
206
201
|
instance = cls._instance
|
|
207
202
|
if not instance:
|
|
@@ -224,14 +219,41 @@ class PrecisionDebugger(BasePrecisionDebugger):
|
|
|
224
219
|
instance = cls._instance
|
|
225
220
|
if not instance:
|
|
226
221
|
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
222
|
+
if instance.config.level_ori == Const.LEVEL_L2:
|
|
223
|
+
return False
|
|
227
224
|
if instance.config.execution_mode != MsConst.PYNATIVE_MODE:
|
|
228
225
|
return False
|
|
229
226
|
else:
|
|
230
|
-
return instance.config.task != Const.FREE_BENCHMARK
|
|
227
|
+
return instance.config.task != Const.FREE_BENCHMARK
|
|
231
228
|
|
|
232
229
|
@classmethod
|
|
233
|
-
def
|
|
230
|
+
def _is_kernel_dump(cls):
|
|
234
231
|
instance = cls._instance
|
|
235
232
|
if not instance:
|
|
236
233
|
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
237
234
|
return instance.config.level_ori == Const.LEVEL_L2
|
|
235
|
+
|
|
236
|
+
@classmethod
|
|
237
|
+
def _start_kernel_dump(cls):
|
|
238
|
+
instance = cls._get_instance()
|
|
239
|
+
is_graph_config = cls._is_graph_dump(instance.config)
|
|
240
|
+
instance.config.check_config_with_l2(is_graph_config)
|
|
241
|
+
if not is_graph_config:
|
|
242
|
+
if not instance.service:
|
|
243
|
+
instance.service = MindsporeService(instance.config)
|
|
244
|
+
instance.service.start()
|
|
245
|
+
else:
|
|
246
|
+
if _msprobe_c:
|
|
247
|
+
_msprobe_c._PrecisionDebugger().start()
|
|
248
|
+
if not instance.first_start:
|
|
249
|
+
get_api_register().restore_all_api()
|
|
250
|
+
handlers = TaskHandlerFactory.create(instance.config)
|
|
251
|
+
for handler in handlers:
|
|
252
|
+
handler.handle()
|
|
253
|
+
if enable_dynamic_kbyk_dump:
|
|
254
|
+
_set_init_iter(0)
|
|
255
|
+
if enable_dynamic_kbyk_dump:
|
|
256
|
+
is_valid_rank = (not instance.config.rank or Runtime.rank_id in instance.config.rank)
|
|
257
|
+
is_valid_step = (not instance.config.step or Runtime.step_count in instance.config.step)
|
|
258
|
+
if is_valid_rank and is_valid_step:
|
|
259
|
+
_dump_start()
|
|
@@ -51,6 +51,8 @@ class DumpToolFactory:
|
|
|
51
51
|
else:
|
|
52
52
|
if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST:
|
|
53
53
|
raise Exception("data_mode must be one of all, input, output.")
|
|
54
|
+
if config.level == Const.KERNEL:
|
|
55
|
+
return (KernelGraphDump(config), KernelKbykDump(config))
|
|
54
56
|
tool = DumpToolFactory.tools.get(config.level)
|
|
55
57
|
if not tool:
|
|
56
58
|
raise Exception("Valid level is needed.")
|
|
@@ -13,15 +13,16 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import threading
|
|
16
17
|
from collections import defaultdict
|
|
17
18
|
|
|
18
19
|
import mindspore as ms
|
|
19
20
|
from mindspore import nn
|
|
20
21
|
|
|
21
22
|
from msprobe.core.common.runtime import Runtime
|
|
23
|
+
from msprobe.core.common.utils import ThreadSafe
|
|
22
24
|
from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions
|
|
23
25
|
|
|
24
|
-
|
|
25
26
|
ms_version = ms.__version__
|
|
26
27
|
|
|
27
28
|
|
|
@@ -35,16 +36,17 @@ def get_cell_count(name):
|
|
|
35
36
|
|
|
36
37
|
def __init__(self, hook_build_func) -> None:
|
|
37
38
|
super(HOOKCell, self).__init__()
|
|
38
|
-
self.changed_status = False
|
|
39
39
|
self.msprobe_input_kwargs = {}
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
40
|
+
|
|
41
|
+
self.tid = threading.get_ident()
|
|
42
|
+
self.stop_hook = HOOKCell.inner_stop_hook.get(self.tid, False)
|
|
43
|
+
if not self.stop_hook:
|
|
43
44
|
self.forward_data_collected = False
|
|
44
45
|
|
|
45
46
|
if not Runtime.is_running:
|
|
46
47
|
return
|
|
47
48
|
prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
|
|
49
|
+
ThreadSafe.acquire()
|
|
48
50
|
if callable(hook_build_func):
|
|
49
51
|
hook_set = hook_build_func(prefix)
|
|
50
52
|
if ms_version < "2.6.0" and not is_mindtorch():
|
|
@@ -59,21 +61,24 @@ def __init__(self, hook_build_func) -> None:
|
|
|
59
61
|
|
|
60
62
|
# 重载call,加全局标志。
|
|
61
63
|
def __call__(self, *args, **kwargs):
|
|
64
|
+
changed = False
|
|
65
|
+
if not self.stop_hook:
|
|
66
|
+
HOOKCell.inner_stop_hook[self.tid] = True
|
|
67
|
+
changed = True
|
|
62
68
|
try:
|
|
63
69
|
self.msprobe_input_kwargs = kwargs
|
|
64
70
|
out = super(HOOKCell, self).__call__(*args, **kwargs)
|
|
65
71
|
except Exception as e:
|
|
66
72
|
raise e
|
|
67
73
|
finally:
|
|
68
|
-
if
|
|
69
|
-
self.
|
|
70
|
-
HOOKCell.g_stop_hook = False
|
|
74
|
+
if changed:
|
|
75
|
+
HOOKCell.inner_stop_hook[self.tid] = False
|
|
71
76
|
return out
|
|
72
77
|
|
|
73
78
|
|
|
74
79
|
hook_cell_dict = {
|
|
75
80
|
"cell_count": defaultdict(int),
|
|
76
|
-
"
|
|
81
|
+
"inner_stop_hook": defaultdict(bool),
|
|
77
82
|
"add_cell_count": staticmethod(add_cell_count),
|
|
78
83
|
"get_cell_count": staticmethod(get_cell_count),
|
|
79
84
|
"__init__": __init__,
|
|
@@ -13,9 +13,11 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import threading
|
|
17
|
+
|
|
16
18
|
from mindspore.common.api import _no_grad
|
|
17
19
|
from msprobe.core.common.const import Const
|
|
18
|
-
from msprobe.core.common.utils import replace_last_occurrence
|
|
20
|
+
from msprobe.core.common.utils import replace_last_occurrence, ThreadSafe
|
|
19
21
|
from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputs
|
|
20
22
|
from msprobe.core.hook_manager import BaseHookManager, HookSet
|
|
21
23
|
from msprobe.mindspore.common.utils import has_kwargs_in_forward_hook
|
|
@@ -78,11 +80,14 @@ class MindsproeHookManager(BaseHookManager):
|
|
|
78
80
|
def backward_pre_hook(module, grad_input):
|
|
79
81
|
if self.config.level != Const.LEVEL_L2:
|
|
80
82
|
return
|
|
81
|
-
|
|
83
|
+
tid = threading.get_ident()
|
|
84
|
+
if not self._should_execute_hook(hook_type, module, False, tid):
|
|
82
85
|
return
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
86
|
+
|
|
87
|
+
with ThreadSafe():
|
|
88
|
+
BaseHookManager.inner_switch[tid] = True
|
|
89
|
+
module_input = ModuleBackwardInputs(grad_input=grad_input)
|
|
90
|
+
self.data_collector.update_api_or_module_name(name)
|
|
91
|
+
self.data_collector.backward_input_data_collect(name, module, self._pid, module_input)
|
|
92
|
+
BaseHookManager.inner_switch[tid] = False
|
|
88
93
|
return backward_pre_hook
|
|
@@ -14,13 +14,16 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
+
import threading
|
|
17
18
|
|
|
18
19
|
from mindspore import ops
|
|
19
20
|
from mindspore.common.tensor import Tensor
|
|
20
|
-
|
|
21
|
-
from msprobe.core.
|
|
22
|
-
|
|
23
|
-
|
|
21
|
+
from msprobe.core.common.utils import Const, DumpException, ThreadSafe
|
|
22
|
+
from msprobe.core.data_dump.data_processor.base import (
|
|
23
|
+
ModuleBackwardInputs,
|
|
24
|
+
ModuleBackwardOutputs,
|
|
25
|
+
ModuleForwardInputsOutputs
|
|
26
|
+
)
|
|
24
27
|
from msprobe.core.hook_manager import BaseHookManager
|
|
25
28
|
from msprobe.mindspore.common.log import logger
|
|
26
29
|
|
|
@@ -56,10 +59,13 @@ class PrimitiveHookService:
|
|
|
56
59
|
callable: 反向 hook 函数。
|
|
57
60
|
"""
|
|
58
61
|
|
|
62
|
+
@ThreadSafe.synchronized
|
|
59
63
|
def backward_hook(grad):
|
|
64
|
+
tid = threading.get_ident()
|
|
65
|
+
BaseHookManager.inner_switch[tid] = True
|
|
66
|
+
|
|
60
67
|
captured_grads.extend(grad)
|
|
61
68
|
backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}"
|
|
62
|
-
self.service_instance.inner_switch = True
|
|
63
69
|
try:
|
|
64
70
|
if hook_type == Const.INPUT:
|
|
65
71
|
self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
|
|
@@ -78,7 +84,7 @@ class PrimitiveHookService:
|
|
|
78
84
|
logger.error(f"This is a primitive op {hook_type}_backward dump error: {exception}, "
|
|
79
85
|
f"updated_primitive_name: {updated_primitive_name}")
|
|
80
86
|
raise DumpException(DumpException.BACKWARD_DATA_COLLECTION_ERROR) from exception
|
|
81
|
-
|
|
87
|
+
BaseHookManager.inner_switch[tid] = False
|
|
82
88
|
|
|
83
89
|
return backward_hook
|
|
84
90
|
|
|
@@ -137,9 +143,13 @@ class PrimitiveHookService:
|
|
|
137
143
|
return tuple(hooked_outputs)
|
|
138
144
|
return out
|
|
139
145
|
|
|
146
|
+
@ThreadSafe.synchronized
|
|
140
147
|
def pre_forward_hook(primitive_name, primitive_instance, args, kwargs):
|
|
148
|
+
tid = threading.get_ident()
|
|
149
|
+
BaseHookManager.inner_switch[tid] = True
|
|
150
|
+
|
|
151
|
+
self.service_instance.data_collector.update_api_or_module_name(primitive_name)
|
|
141
152
|
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
|
|
142
|
-
self.service_instance.inner_switch = True
|
|
143
153
|
try:
|
|
144
154
|
self.service_instance.data_collector.forward_input_data_collect(
|
|
145
155
|
primitive_name,
|
|
@@ -151,11 +161,15 @@ class PrimitiveHookService:
|
|
|
151
161
|
logger.error(f"This is a primitive op dump error during forward input data collection: {exception}, "
|
|
152
162
|
f"primitive_name: {primitive_name}")
|
|
153
163
|
raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
|
|
154
|
-
|
|
164
|
+
BaseHookManager.inner_switch[tid] = False
|
|
155
165
|
|
|
166
|
+
@ThreadSafe.synchronized
|
|
156
167
|
def post_forward_hook(primitive_name, primitive_instance, args, kwargs, output):
|
|
168
|
+
tid = threading.get_ident()
|
|
169
|
+
BaseHookManager.inner_switch[tid] = True
|
|
170
|
+
|
|
171
|
+
self.service_instance.data_collector.update_api_or_module_name(primitive_name)
|
|
157
172
|
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
|
|
158
|
-
self.service_instance.inner_switch = True
|
|
159
173
|
try:
|
|
160
174
|
self.service_instance.data_collector.forward_output_data_collect(
|
|
161
175
|
primitive_name,
|
|
@@ -167,7 +181,7 @@ class PrimitiveHookService:
|
|
|
167
181
|
logger.error(f"This is a primitive op dump error during forward output data collection: {exception}, "
|
|
168
182
|
f"primitive_name: {primitive_name}")
|
|
169
183
|
raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
|
|
170
|
-
|
|
184
|
+
BaseHookManager.inner_switch[tid] = False
|
|
171
185
|
|
|
172
186
|
def wrapped_primitive_call(instance_self, *args, **kwargs):
|
|
173
187
|
"""
|
|
@@ -185,7 +199,8 @@ class PrimitiveHookService:
|
|
|
185
199
|
current_count = self.primitive_counters.get(primitive_name, 0)
|
|
186
200
|
updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}{Const.SEP}{primitive_name}{Const.SEP}{current_count}"
|
|
187
201
|
|
|
188
|
-
|
|
202
|
+
tid = threading.get_ident()
|
|
203
|
+
if not self.service_instance.primitive_switch or BaseHookManager.inner_switch[tid]:
|
|
189
204
|
return origin_func(*args, **kwargs)
|
|
190
205
|
|
|
191
206
|
captured_grads_input, captured_grads_output = [], []
|
|
@@ -198,8 +213,6 @@ class PrimitiveHookService:
|
|
|
198
213
|
raise DumpException(DumpException.INPUT_HOOK_ERROR) from exception
|
|
199
214
|
|
|
200
215
|
forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}"
|
|
201
|
-
self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
|
|
202
|
-
|
|
203
216
|
pre_forward_hook(forward_primitive_name, instance_self, hooked_inputs, kwargs)
|
|
204
217
|
try:
|
|
205
218
|
out = origin_func(*hooked_inputs, **kwargs)
|
|
@@ -220,6 +233,7 @@ class PrimitiveHookService:
|
|
|
220
233
|
|
|
221
234
|
return wrapped_primitive_call
|
|
222
235
|
|
|
236
|
+
@ThreadSafe.synchronized
|
|
223
237
|
def update_primitive_counters(self, primitive_name):
|
|
224
238
|
if primitive_name not in self.primitive_counters:
|
|
225
239
|
self.primitive_counters[primitive_name] = 0
|
|
@@ -13,13 +13,14 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
from collections import defaultdict
|
|
17
16
|
import os
|
|
18
17
|
import types
|
|
18
|
+
from collections import defaultdict
|
|
19
19
|
|
|
20
20
|
import mindspore
|
|
21
21
|
from mindspore import nn
|
|
22
22
|
from mindspore._c_expression import PyNativeExecutor_
|
|
23
|
+
|
|
23
24
|
try:
|
|
24
25
|
from mindspore.common.api import _MindsporeFunctionExecutor
|
|
25
26
|
except ImportError:
|
|
@@ -27,19 +28,21 @@ except ImportError:
|
|
|
27
28
|
|
|
28
29
|
from msprobe.core.common.log import logger
|
|
29
30
|
from msprobe.core.common.const import Const
|
|
31
|
+
from msprobe.core.common.utils import ThreadSafe
|
|
30
32
|
from msprobe.core.common.runtime import Runtime
|
|
31
33
|
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
|
|
32
34
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
33
35
|
from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
|
|
34
36
|
|
|
35
|
-
|
|
36
37
|
_api_register = get_api_register()
|
|
37
38
|
|
|
38
39
|
|
|
39
40
|
def dump_jit(name, in_feat, out_feat, is_forward):
|
|
40
41
|
pid = os.getpid()
|
|
41
42
|
name = name if name else "JitFunction"
|
|
42
|
-
if JitDump.need_dump():
|
|
43
|
+
if not JitDump.need_dump():
|
|
44
|
+
return
|
|
45
|
+
with ThreadSafe():
|
|
43
46
|
if is_forward:
|
|
44
47
|
if name in JitDump.jit_count:
|
|
45
48
|
JitDump.jit_count[name] += 1
|
|
@@ -39,12 +39,19 @@ class KernelKbykDump:
|
|
|
39
39
|
common_set["input_output"] = 0
|
|
40
40
|
common_set["kernels"] = []
|
|
41
41
|
common_set["support_device"] = [0, 1, 2, 3, 4, 5, 6, 7]
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
42
|
+
|
|
43
|
+
if config.stat_cal_mode and config.device_stat_precision_mode:
|
|
44
|
+
e2e_set = {
|
|
45
|
+
"enable": not config.async_dump,
|
|
46
|
+
"trans_flag": True,
|
|
47
|
+
"stat_calc_mode": config.stat_cal_mode,
|
|
48
|
+
"device_stat_precision_mode": config.device_stat_precision_mode
|
|
49
|
+
}
|
|
50
|
+
else:
|
|
51
|
+
e2e_set = {
|
|
52
|
+
"enable": not config.async_dump,
|
|
53
|
+
"trans_flag": True
|
|
54
|
+
}
|
|
48
55
|
|
|
49
56
|
if config.list:
|
|
50
57
|
common_set["dump_mode"] = 1
|
|
@@ -23,13 +23,14 @@
|
|
|
23
23
|
|
|
24
24
|
namespace py = pybind11;
|
|
25
25
|
|
|
26
|
-
HookDynamicLoader &HookDynamicLoader::GetInstance()
|
|
26
|
+
HookDynamicLoader &HookDynamicLoader::GetInstance()
|
|
27
27
|
{
|
|
28
28
|
static HookDynamicLoader instance;
|
|
29
29
|
return instance;
|
|
30
30
|
}
|
|
31
31
|
|
|
32
|
-
bool HookDynamicLoader::LoadFunction(void *handle, const std::string &functionName)
|
|
32
|
+
bool HookDynamicLoader::LoadFunction(void *handle, const std::string &functionName)
|
|
33
|
+
{
|
|
33
34
|
void *func = dlsym(handle, functionName.c_str());
|
|
34
35
|
if (!func) {
|
|
35
36
|
MS_LOG(WARNING) << "Could not load function: " << functionName << ", error: " << dlerror();
|
|
@@ -83,7 +84,7 @@ bool HookDynamicLoader::LoadLibrary()
|
|
|
83
84
|
return true;
|
|
84
85
|
}
|
|
85
86
|
|
|
86
|
-
bool HookDynamicLoader::UnloadLibrary()
|
|
87
|
+
bool HookDynamicLoader::UnloadLibrary()
|
|
87
88
|
{
|
|
88
89
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
89
90
|
if (!handle_) {
|
|
@@ -103,8 +104,8 @@ void *HookDynamicLoader::GetHooker(const std::string &funcName)
|
|
|
103
104
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
104
105
|
auto iter = funcMap_.find(funcName);
|
|
105
106
|
if (iter == funcMap_.end()) {
|
|
106
|
-
|
|
107
|
-
|
|
107
|
+
MS_LOG(WARNING) << "Function not found: " << funcName;
|
|
108
|
+
return nullptr;
|
|
108
109
|
}
|
|
109
110
|
return iter->second;
|
|
110
111
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
|
|
2
|
-
* Copyright 2024 Huawei Technologies Co., Ltd
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved.
|
|
3
3
|
*
|
|
4
4
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
5
|
* you may not use this file except in compliance with the License.
|
|
@@ -39,11 +39,11 @@ else:
|
|
|
39
39
|
pijit_label = True
|
|
40
40
|
|
|
41
41
|
|
|
42
|
-
class MindsporeService(BaseService):
|
|
42
|
+
class MindsporeService(BaseService):
|
|
43
43
|
@property
|
|
44
44
|
def _get_framework_type(self):
|
|
45
45
|
return Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
|
|
46
|
-
|
|
46
|
+
|
|
47
47
|
@staticmethod
|
|
48
48
|
def _get_current_rank():
|
|
49
49
|
return get_rank_if_initialized()
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
from mindspore import nn
|
|
18
18
|
from mindspore import communication
|
|
19
|
-
from msprobe.
|
|
19
|
+
from msprobe.core.common.log import logger
|
|
20
20
|
from msprobe.mindspore.common.utils import is_mindtorch
|
|
21
21
|
if is_mindtorch():
|
|
22
22
|
import torch
|
|
@@ -28,11 +28,12 @@ from mindspore import nn, _no_grad
|
|
|
28
28
|
from msprobe.core.common.log import logger
|
|
29
29
|
from msprobe.core.common.const import MonitorConst, Const
|
|
30
30
|
from msprobe.core.common.file_utils import load_json, save_json
|
|
31
|
+
from msprobe.core.monitor.utils import validate_config, get_output_base_dir, get_target_output_dir
|
|
31
32
|
from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter
|
|
32
33
|
from msprobe.mindspore.common.utils import is_mindtorch
|
|
33
34
|
from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank
|
|
34
|
-
from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name,
|
|
35
|
-
|
|
35
|
+
from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, step_accumulates_one, is_skip_step, \
|
|
36
|
+
get_metrics
|
|
36
37
|
from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory
|
|
37
38
|
from msprobe.mindspore.monitor.data_writers import CSVWriterWithAD, BaseWriterWithAD, WriterInput
|
|
38
39
|
from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate
|
|
@@ -250,7 +251,6 @@ class TrainerMon:
|
|
|
250
251
|
self.has_collect_times = 0 # 重设采集计数器
|
|
251
252
|
self.print_struct = self.config.get("print_struct", False)
|
|
252
253
|
self.targets = self.config.get("targets", None)
|
|
253
|
-
self.is_select = self.config.get("is_select", False)
|
|
254
254
|
self.module_rank_list = self.config.get("module_ranks", [])
|
|
255
255
|
self.format = self.config.get('format', MonitorConst.CSV) # only csv supported in mindspore
|
|
256
256
|
self.eps = self.config.get('eps', 1e-8)
|