mindstudio-probe 8.2.0__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +63 -61
- msprobe/README.md +4 -4
- msprobe/core/common/const.py +6 -0
- msprobe/core/common/db_manager.py +35 -4
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/utils.py +14 -3
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +16 -4
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/analyzer.py +8 -7
- msprobe/core/compare/find_first/graph.py +11 -3
- msprobe/core/compare/find_first/utils.py +3 -2
- msprobe/core/compare/highlight.py +13 -6
- msprobe/core/compare/multiprocessing_compute.py +17 -10
- msprobe/core/compare/utils.py +14 -5
- msprobe/core/data_dump/data_collector.py +18 -21
- msprobe/core/data_dump/data_processor/pytorch_processor.py +43 -20
- msprobe/core/data_dump/json_writer.py +18 -8
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +21 -0
- msprobe/core/service.py +2 -0
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +7 -5
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/06.data_dump_MindSpore.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +46 -5
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +2 -0
- msprobe/docs/21.visualization_PyTorch.md +15 -80
- msprobe/docs/22.visualization_MindSpore.md +20 -104
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/cell_processor.py +33 -5
- msprobe/mindspore/compare/common_dir_compare.py +22 -26
- msprobe/mindspore/debugger/precision_debugger.py +1 -1
- msprobe/mindspore/dump/cell_dump_process.py +73 -62
- msprobe/mindspore/dump/graph_mode_cell_dump.py +21 -10
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +2 -0
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +15 -8
- msprobe/pytorch/monitor/module_hook.py +28 -9
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/visualization/builder/graph_builder.py +169 -64
- msprobe/visualization/builder/graph_merger.py +0 -1
- msprobe/visualization/builder/msprobe_adapter.py +1 -1
- msprobe/visualization/db_utils.py +25 -2
- msprobe/visualization/graph/base_node.py +0 -24
- msprobe/visualization/graph/graph.py +5 -14
- msprobe/visualization/graph_service.py +29 -53
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
|
@@ -14,7 +14,8 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
-
|
|
17
|
+
import glob
|
|
18
|
+
import tempfile
|
|
18
19
|
import mindspore as ms
|
|
19
20
|
from mindspore import hal, ops, Tensor
|
|
20
21
|
from mindspore.ops.primitive import _run_op
|
|
@@ -28,6 +29,7 @@ import msprobe.mindspore.dump.cell_dump_process as cellDumperWithDumpGradient
|
|
|
28
29
|
import msprobe.mindspore.dump.cell_dump_with_insert_gradient as cellDumperWithInsertGradient
|
|
29
30
|
|
|
30
31
|
tensordump_flag = True
|
|
32
|
+
DEFAULT_RANK_DIR = "rank0"
|
|
31
33
|
try:
|
|
32
34
|
from mindspore._c_expression import _tensordump_set_step
|
|
33
35
|
except ImportError:
|
|
@@ -41,8 +43,6 @@ except ImportError:
|
|
|
41
43
|
|
|
42
44
|
|
|
43
45
|
class GraphModeCellDump:
|
|
44
|
-
task = CoreConst.STATISTICS
|
|
45
|
-
|
|
46
46
|
def __init__(self, config: DebuggerConfig, model, strict=True):
|
|
47
47
|
self.net = model
|
|
48
48
|
self.white_list = []
|
|
@@ -55,29 +55,40 @@ class GraphModeCellDump:
|
|
|
55
55
|
self.list = config.list
|
|
56
56
|
self.data_mode = config.data_mode
|
|
57
57
|
self.file_format = config.file_format
|
|
58
|
-
GraphModeCellDump.task = config.task
|
|
59
58
|
self.summary_mode = config.summary_mode
|
|
59
|
+
self.task = config.task
|
|
60
60
|
self.check_config(strict)
|
|
61
61
|
self.set_step()
|
|
62
62
|
|
|
63
63
|
@staticmethod
|
|
64
|
-
def step():
|
|
64
|
+
def step(dump_path, step_list, task):
|
|
65
65
|
# 更新TensorDump Step
|
|
66
|
-
if
|
|
66
|
+
if task == CoreConst.TENSOR:
|
|
67
67
|
hal.synchronize()
|
|
68
68
|
temp_tensor = ms.Tensor([1], dtype=ms.float32)
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
69
|
+
rank_id = os.environ.get('RANK_ID')
|
|
70
|
+
rank_dir = DEFAULT_RANK_DIR
|
|
71
|
+
|
|
72
|
+
if rank_id is not None:
|
|
73
|
+
rank_dir = CoreConst.RANK + str(rank_id)
|
|
74
|
+
|
|
75
|
+
with tempfile.TemporaryDirectory(dir=dump_path, prefix=rank_dir) as temp_dir:
|
|
76
|
+
save_file_flag = f"{temp_dir}/step_{Runtime.step_count}"
|
|
77
|
+
_run_op(ops.TensorDump(), "TensorDump", (save_file_flag, temp_tensor))
|
|
78
|
+
step_flag = "<tensordump-update-step>"
|
|
79
|
+
_run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor))
|
|
80
|
+
ops.tensordump(step_flag, temp_tensor)
|
|
81
|
+
cellDumperWithDumpGradient.process_step(dump_path, temp_dir, Runtime.step_count, step_list)
|
|
72
82
|
|
|
73
83
|
# 更新静态图KBK dump的step数
|
|
74
|
-
if
|
|
84
|
+
if task == CoreConst.STATISTICS:
|
|
75
85
|
if not graph_step_flag:
|
|
76
86
|
raise Exception(
|
|
77
87
|
"Importing _dump_step failed, "
|
|
78
88
|
"please use the latest version package of MindSpore."
|
|
79
89
|
)
|
|
80
90
|
_dump_step(1)
|
|
91
|
+
cellDumperWithDumpGradient.process_statistics_step(dump_path, Runtime.step_count, step_list)
|
|
81
92
|
|
|
82
93
|
def check_config(self, strict):
|
|
83
94
|
if not self.net:
|
|
@@ -203,10 +203,12 @@ class MindsporeHookManager(BaseHookManager):
|
|
|
203
203
|
return
|
|
204
204
|
|
|
205
205
|
with ThreadSafe():
|
|
206
|
+
original_state = self.ensure_gc_enabled()
|
|
206
207
|
BaseHookManager.inner_switch[tid] = True
|
|
207
208
|
module_input = ModuleBackwardInputs(grad_input=grad_input)
|
|
208
209
|
self.data_collector.update_api_or_module_name(full_name)
|
|
209
210
|
self.data_collector.backward_input_data_collect(full_name, module, self._pid, module_input)
|
|
210
211
|
BaseHookManager.inner_switch[tid] = False
|
|
212
|
+
self.restore_gc_state(original_state)
|
|
211
213
|
|
|
212
214
|
return backward_pre_hook
|
msprobe/pytorch/compare/utils.py
CHANGED
|
@@ -35,7 +35,8 @@ def read_pt_data(dir_path, file_name):
|
|
|
35
35
|
data_value = load_pt(data_path, to_cpu=True).detach()
|
|
36
36
|
except RuntimeError as e:
|
|
37
37
|
# 这里捕获 load_pt 中抛出的异常
|
|
38
|
-
|
|
38
|
+
data_path_file_name = os.path.basename(data_path)
|
|
39
|
+
logger.error(f"Failed to load the .pt file at {data_path_file_name}.")
|
|
39
40
|
raise CompareException(CompareException.INVALID_FILE_ERROR) from e
|
|
40
41
|
except AttributeError as e:
|
|
41
42
|
# 这里捕获 detach 方法抛出的异常
|
|
@@ -24,8 +24,11 @@ from msprobe.pytorch.common.log import logger
|
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
def wrap_setup_backward_hook(func):
|
|
27
|
-
def requires_clone(tensor):
|
|
28
|
-
|
|
27
|
+
def requires_clone(tensor, need_check_leaf=False):
|
|
28
|
+
need_clone = isinstance(tensor, torch.Tensor) and tensor.requires_grad and torch.is_grad_enabled()
|
|
29
|
+
if need_check_leaf:
|
|
30
|
+
need_clone &= tensor.grad_fn is not None
|
|
31
|
+
return need_clone
|
|
29
32
|
|
|
30
33
|
@recursion_depth_decorator("Dump: wrap_setup_backward_hook.parse_tensor", max_depth=Const.DUMP_MAX_DEPTH)
|
|
31
34
|
def parse_tensor(item, tensor_list):
|
|
@@ -39,20 +42,20 @@ def wrap_setup_backward_hook(func):
|
|
|
39
42
|
parse_tensor(value, tensor_list)
|
|
40
43
|
|
|
41
44
|
@recursion_depth_decorator("Dump: wrap_setup_backward_hook.rebuild_args", max_depth=Const.DUMP_MAX_DEPTH)
|
|
42
|
-
def rebuild_args(item, tensor_iter):
|
|
43
|
-
if requires_clone(item):
|
|
45
|
+
def rebuild_args(item, tensor_iter, need_check_leaf=False):
|
|
46
|
+
if requires_clone(item, need_check_leaf):
|
|
44
47
|
result = next(tensor_iter)
|
|
45
48
|
if hasattr(result, "_base") and result._base is not None:
|
|
46
49
|
if torch._C._autograd._get_creation_meta(result) != torch._C._autograd.CreationMeta(0):
|
|
47
50
|
torch._C._autograd._set_creation_meta(result, torch._C._autograd.CreationMeta(0))
|
|
48
|
-
return result
|
|
51
|
+
return result
|
|
49
52
|
if isinstance(item, list):
|
|
50
53
|
for index, value in enumerate(item):
|
|
51
|
-
item[index] = rebuild_args(value, tensor_iter)
|
|
54
|
+
item[index] = rebuild_args(value, tensor_iter, need_check_leaf=True)
|
|
52
55
|
return item
|
|
53
56
|
if isinstance(item, dict):
|
|
54
57
|
for key, value in item.items():
|
|
55
|
-
item[key] = rebuild_args(value, tensor_iter)
|
|
58
|
+
item[key] = rebuild_args(value, tensor_iter, need_check_leaf=True)
|
|
56
59
|
return item
|
|
57
60
|
if isinstance(item, tuple):
|
|
58
61
|
if hasattr(item, '_fields'):
|
|
@@ -23,13 +23,15 @@ from torch.utils.hooks import BackwardHook, RemovableHandle
|
|
|
23
23
|
from msprobe.core.common.const import Const
|
|
24
24
|
from msprobe.core.common.runtime import Runtime
|
|
25
25
|
from msprobe.core.common.utils import ModuleQueue, ThreadSafe
|
|
26
|
+
from msprobe.core.common.megatron_utils import wrap_megatron_step, get_micro_step, is_megatron
|
|
26
27
|
from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
|
|
27
28
|
from msprobe.pytorch.common.log import logger
|
|
28
29
|
from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook
|
|
29
30
|
from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_output_hook
|
|
30
31
|
|
|
31
32
|
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
32
|
-
|
|
33
|
+
torch_version_above_or_equal_21 = torch.__version__.split('+')[0] >= '2.1'
|
|
34
|
+
if torch_version_above_or_equal_21:
|
|
33
35
|
from torch.utils.checkpoint import _StopRecomputationError
|
|
34
36
|
|
|
35
37
|
|
|
@@ -61,7 +63,7 @@ def wrap_forward_with_hook_safety(module):
|
|
|
61
63
|
hook_fn = list(module._forward_hooks.values())[0]
|
|
62
64
|
hook_fn(module, args, kwargs, exception_output)
|
|
63
65
|
raise e
|
|
64
|
-
if
|
|
66
|
+
if torch_version_above_or_equal_21:
|
|
65
67
|
module.forward = wrapped_forward
|
|
66
68
|
|
|
67
69
|
|
|
@@ -82,6 +84,8 @@ class ModuleProcesser:
|
|
|
82
84
|
from megatron.core.pipeline_parallel import schedules
|
|
83
85
|
origin_func_id = id(schedules.deallocate_output_tensor)
|
|
84
86
|
schedules.deallocate_output_tensor = wrap_megatron_deallocate(schedules.deallocate_output_tensor)
|
|
87
|
+
schedules.forward_step = wrap_megatron_step(schedules.forward_step)
|
|
88
|
+
schedules.backward_step = wrap_megatron_step(schedules.backward_step, is_forward=False)
|
|
85
89
|
for module in list(sys.modules.values()):
|
|
86
90
|
if module.__name__ == 'schedules':
|
|
87
91
|
continue
|
|
@@ -258,14 +262,16 @@ class ModuleProcesser:
|
|
|
258
262
|
ModuleProcesser.module_stack[tid] = []
|
|
259
263
|
|
|
260
264
|
if self.module_stack[tid]:
|
|
261
|
-
ModuleProcesser.module_node[full_name] = self.module_stack[tid][-1]
|
|
265
|
+
ModuleProcesser.module_node[full_name] = self.module_stack[tid][-1] if not is_megatron() \
|
|
266
|
+
else [self.module_stack[tid][-1], get_micro_step()]
|
|
262
267
|
else:
|
|
263
268
|
parent_name = ModuleProcesser.module_queue.find_last(full_name)
|
|
264
|
-
ModuleProcesser.module_node[full_name] = parent_name
|
|
269
|
+
ModuleProcesser.module_node[full_name] = parent_name if not is_megatron() \
|
|
270
|
+
else [parent_name, get_micro_step()]
|
|
265
271
|
|
|
266
272
|
ModuleProcesser.module_queue.add_name(full_name)
|
|
267
273
|
ModuleProcesser.module_stack[tid].append(full_name)
|
|
268
|
-
ModuleProcesser.api_parent_node[tid] = full_name
|
|
274
|
+
ModuleProcesser.api_parent_node[tid] = full_name if not is_megatron() else [full_name, get_micro_step()]
|
|
269
275
|
if self.scope:
|
|
270
276
|
self.scope.begin_module(full_name)
|
|
271
277
|
|
|
@@ -273,14 +279,15 @@ class ModuleProcesser:
|
|
|
273
279
|
tid = threading.get_ident()
|
|
274
280
|
if torch_version_above_or_equal_2 or is_forward:
|
|
275
281
|
ModuleProcesser.module_queue.remove_name(full_name)
|
|
276
|
-
ModuleProcesser.api_parent_node[tid] = None
|
|
282
|
+
ModuleProcesser.api_parent_node[tid] = None if not is_megatron() else [None, get_micro_step()]
|
|
277
283
|
if self.module_stack.get(tid):
|
|
278
284
|
ModuleProcesser.module_stack[tid].pop()
|
|
279
285
|
if self.module_stack.get(tid):
|
|
280
|
-
ModuleProcesser.api_parent_node[tid] = ModuleProcesser.module_stack[tid][-1]
|
|
286
|
+
ModuleProcesser.api_parent_node[tid] = ModuleProcesser.module_stack[tid][-1] if not is_megatron() \
|
|
287
|
+
else [ModuleProcesser.module_stack[tid][-1], get_micro_step()]
|
|
281
288
|
if self.scope:
|
|
282
289
|
self.scope.end_module(full_name)
|
|
283
290
|
else:
|
|
284
291
|
if self.scope:
|
|
285
292
|
self.scope.begin_module(full_name)
|
|
286
|
-
ModuleProcesser.api_parent_node[tid] = full_name
|
|
293
|
+
ModuleProcesser.api_parent_node[tid] = full_name if not is_megatron() else [full_name, get_micro_step()]
|
|
@@ -19,12 +19,14 @@ import importlib
|
|
|
19
19
|
from collections import defaultdict
|
|
20
20
|
from datetime import datetime
|
|
21
21
|
from functools import partial
|
|
22
|
+
from itertools import cycle
|
|
22
23
|
|
|
23
24
|
import pytz
|
|
24
25
|
import torch
|
|
25
26
|
import torch.distributed as dist
|
|
26
27
|
import pandas as pd
|
|
27
28
|
from torch.utils.hooks import BackwardHook
|
|
29
|
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
28
30
|
|
|
29
31
|
from msprobe.core.common.const import MonitorConst, Const
|
|
30
32
|
from msprobe.core.common.file_utils import load_json, save_json, make_dir
|
|
@@ -229,6 +231,8 @@ class TrainerMon:
|
|
|
229
231
|
self.duplicate_param = {}
|
|
230
232
|
self.name2tag = {}
|
|
231
233
|
self.param_name_call_id = {}
|
|
234
|
+
self.flat_prefix_names = []
|
|
235
|
+
self.flat_prefix_reverse_iter = None
|
|
232
236
|
self.call_id = 0
|
|
233
237
|
self.module_struct = defaultdict(dict)
|
|
234
238
|
self.grad_accs = []
|
|
@@ -945,13 +949,20 @@ class TrainerMon:
|
|
|
945
949
|
return False
|
|
946
950
|
|
|
947
951
|
def _register_chunk(self, model_chunk, prefix):
|
|
952
|
+
if isinstance(model_chunk, FSDP):
|
|
953
|
+
if not model_chunk._use_orig_params:
|
|
954
|
+
raise ValueError("Only Support fsdp1 with use_orig_params=True")
|
|
955
|
+
self.fsdp_wrapped_module = True
|
|
948
956
|
for (param_name, param) in model_chunk.named_parameters():
|
|
949
957
|
if not param.requires_grad:
|
|
950
958
|
continue
|
|
951
|
-
if not self.fsdp_wrapped_module and param_name.startswith("_fsdp_wrapped_module"):
|
|
952
|
-
self.fsdp_wrapped_module = True
|
|
953
959
|
if not self.fsdp2_wrapped_module and param.__class__.__name__ == "DTensor":
|
|
954
960
|
self.fsdp2_wrapped_module = True
|
|
961
|
+
if self.fsdp_wrapped_module: # FSDP1需要记录完整的不被target限制的flat权重前缀名,以供后续对flat解包
|
|
962
|
+
flat_prefix_name, _ = param_name.rsplit(MonitorConst.FSDP_FLAT_SEP, 1)
|
|
963
|
+
if flat_prefix_name not in self.flat_prefix_names:
|
|
964
|
+
self.flat_prefix_names.append(flat_prefix_name)
|
|
965
|
+
|
|
955
966
|
if self._is_target_param(param_name, param, prefix):
|
|
956
967
|
name = prefix + squash_param_name(param_name, self.squash_name)
|
|
957
968
|
if name in self.param2name.values():
|
|
@@ -975,6 +986,8 @@ class TrainerMon:
|
|
|
975
986
|
k: get_summary_writer_tag_name(name, k, self.rank)
|
|
976
987
|
for k in keywords
|
|
977
988
|
}
|
|
989
|
+
if self.fsdp_wrapped_module:
|
|
990
|
+
self.flat_prefix_reverse_iter = cycle(reversed(self.flat_prefix_names)) # post_backward_hook调用顺序是反向的
|
|
978
991
|
|
|
979
992
|
def _register_param_name(self):
|
|
980
993
|
for vpp_stage, model_chunk in enumerate(self.model):
|
|
@@ -1224,17 +1237,22 @@ class TrainerMon:
|
|
|
1224
1237
|
每个forward阶段,fsdp对AccumulateGrad重复注册hook方法,monitor工具内注册hook无法生效,
|
|
1225
1238
|
因此对_post_backward_hook进行patch,在backward后,reduce_scatter前采集梯度。
|
|
1226
1239
|
"""
|
|
1240
|
+
|
|
1227
1241
|
def patch_post_backward_hook(_post_backward_hook):
|
|
1228
1242
|
def wrapper(state, handle, *unused):
|
|
1229
1243
|
grad_dict = {}
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1244
|
+
local_names = handle.flat_param._fqns
|
|
1245
|
+
offsets = handle._get_flat_param_offsets()
|
|
1246
|
+
shapes = handle.flat_param._shapes
|
|
1247
|
+
flat_prefix = next(self.flat_prefix_reverse_iter)
|
|
1248
|
+
for local_name, (start, end), local_shape in zip(local_names, offsets, shapes):
|
|
1249
|
+
grad_clip = handle.flat_param.grad[start:end + 1]
|
|
1250
|
+
grad = grad_clip.reshape(local_shape)
|
|
1251
|
+
total_name = f"{flat_prefix}{MonitorConst.FSDP_FLAT_SEP}{local_name}"
|
|
1252
|
+
if total_name not in self.origin2squash:
|
|
1253
|
+
logger.warning(f"{total_name} not in model.named_parameters(), skip.")
|
|
1234
1254
|
continue
|
|
1235
|
-
|
|
1236
|
-
offset += limit
|
|
1237
|
-
tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
|
|
1255
|
+
tag = self.name2tag.get(self.origin2squash[total_name], {}).get(MonitorConst.PRE_GRAD)
|
|
1238
1256
|
if tag is None:
|
|
1239
1257
|
continue
|
|
1240
1258
|
grad_dict[tag] = grad
|
|
@@ -1242,6 +1260,7 @@ class TrainerMon:
|
|
|
1242
1260
|
get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
|
|
1243
1261
|
out = _post_backward_hook(state, handle, *unused)
|
|
1244
1262
|
return out
|
|
1263
|
+
|
|
1245
1264
|
return wrapper
|
|
1246
1265
|
|
|
1247
1266
|
logger.info("Patch fsdp _post_backward_hook, collect pre_grad metrics.")
|
|
@@ -17,7 +17,7 @@ import json
|
|
|
17
17
|
import os
|
|
18
18
|
import time
|
|
19
19
|
import multiprocessing
|
|
20
|
-
from multiprocessing import Pool
|
|
20
|
+
from multiprocessing import Pool, Lock
|
|
21
21
|
|
|
22
22
|
import torch
|
|
23
23
|
from torch.utils._python_dispatch import TorchDispatchMode
|
|
@@ -39,6 +39,7 @@ from msprobe.pytorch.online_dispatch.utils import get_callstack, data_to_cpu, ge
|
|
|
39
39
|
from msprobe.pytorch.online_dispatch.compare import Comparator
|
|
40
40
|
from msprobe.core.common.utils import check_str_param, safe_get_value
|
|
41
41
|
|
|
42
|
+
child_global_lock = None
|
|
42
43
|
current_time = time.strftime("%Y%m%d%H%M%S")
|
|
43
44
|
RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
|
|
44
45
|
DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
|
|
@@ -86,14 +87,14 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
86
87
|
yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
|
|
87
88
|
self.get_ops(yaml_path)
|
|
88
89
|
|
|
89
|
-
self.lock = None
|
|
90
|
+
self.lock = Lock() if process_num > 0 else None
|
|
90
91
|
max_process_num = max(int((multiprocessing.cpu_count() + 1) // Const.CPU_QUARTER), 1)
|
|
91
92
|
if process_num > max_process_num:
|
|
92
93
|
logger.error(f"process_num should be less than or equal to {max_process_num}, but got {process_num}!")
|
|
93
94
|
raise DispatchException(f'process_num should be less than or equal to {max_process_num}, '
|
|
94
95
|
f'but got {process_num}!')
|
|
95
96
|
if process_num > 0:
|
|
96
|
-
self.pool = Pool(process_num)
|
|
97
|
+
self.pool = Pool(process_num, initializer=self._init_child_process, initargs=(self.lock,))
|
|
97
98
|
if debug:
|
|
98
99
|
logger.info(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} '
|
|
99
100
|
f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], '
|
|
@@ -114,18 +115,17 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
114
115
|
logger.error("Please check train log, An exception may have occurred!")
|
|
115
116
|
return
|
|
116
117
|
check_file_or_directory_path(summary_path, False)
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
fp_handle.close()
|
|
118
|
+
with FileOpen(summary_path, "r") as fp_handle:
|
|
119
|
+
while True:
|
|
120
|
+
json_line_data = fp_handle.readline()
|
|
121
|
+
if json_line_data == '\n':
|
|
122
|
+
continue
|
|
123
|
+
if len(json_line_data) == 0:
|
|
124
|
+
break
|
|
125
|
+
msg = json.loads(json_line_data)
|
|
126
|
+
if len(msg) < 2:
|
|
127
|
+
raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.")
|
|
128
|
+
self.all_summary[msg[0]] = msg[1]
|
|
129
129
|
|
|
130
130
|
if self.debug_flag:
|
|
131
131
|
input_num = 0
|
|
@@ -163,11 +163,16 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
163
163
|
|
|
164
164
|
call_stack = get_callstack()
|
|
165
165
|
self.call_stack_list.append(call_stack)
|
|
166
|
-
|
|
167
|
-
if
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
self.single_api_index_dict
|
|
166
|
+
|
|
167
|
+
self.lock.acquire() if self.process_num > 0 else None
|
|
168
|
+
try:
|
|
169
|
+
self.api_index += 1
|
|
170
|
+
if aten_api not in self.single_api_index_dict:
|
|
171
|
+
self.single_api_index_dict[aten_api] = 1
|
|
172
|
+
else:
|
|
173
|
+
self.single_api_index_dict[aten_api] += 1
|
|
174
|
+
finally:
|
|
175
|
+
self.lock.release() if self.process_num > 0 else None
|
|
171
176
|
|
|
172
177
|
run_param = self.get_run_param(aten_api, func.__name__, aten_api_overload_name)
|
|
173
178
|
|
|
@@ -180,7 +185,7 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
180
185
|
cpu_kwargs = []
|
|
181
186
|
data_to_cpu(args, 0, cpu_args)
|
|
182
187
|
data_to_cpu(kwargs, 0, cpu_kwargs)
|
|
183
|
-
|
|
188
|
+
|
|
184
189
|
cpu_args = safe_get_value(cpu_args, 0, "cpu_args")
|
|
185
190
|
cpu_kwargs = safe_get_value(cpu_kwargs, 0, "cpu_kwargs")
|
|
186
191
|
|
|
@@ -194,7 +199,12 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
194
199
|
try:
|
|
195
200
|
cpu_out = func(*cpu_args, **cpu_kwargs)
|
|
196
201
|
except RuntimeError as e:
|
|
197
|
-
self.
|
|
202
|
+
self.lock.acquire() if self.process_num > 0 else None
|
|
203
|
+
try:
|
|
204
|
+
self.api_index -= 1
|
|
205
|
+
self.single_api_index_dict[aten_api] -= 1
|
|
206
|
+
finally:
|
|
207
|
+
self.lock.release() if self.process_num > 0 else None
|
|
198
208
|
logger.warning(f"RuntimeError: {e}")
|
|
199
209
|
logger.warning(f"This aten_api {aten_api} does not support running on cpu, so skip it.")
|
|
200
210
|
return npu_out
|
|
@@ -215,7 +225,7 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
215
225
|
run_param.process_flag = True
|
|
216
226
|
if self.check_fun(func, run_param):
|
|
217
227
|
data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, None, npu_out_cpu, cpu_out,
|
|
218
|
-
|
|
228
|
+
child_global_lock)
|
|
219
229
|
self.pool.apply_async(func=dispatch_multiprocess, args=(run_param, data_info),
|
|
220
230
|
error_callback=error_call)
|
|
221
231
|
else:
|
|
@@ -233,12 +243,20 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
233
243
|
return True
|
|
234
244
|
return False
|
|
235
245
|
|
|
246
|
+
@staticmethod
|
|
247
|
+
def _init_child_process(lock):
|
|
248
|
+
global child_global_lock
|
|
249
|
+
child_global_lock = lock
|
|
250
|
+
|
|
236
251
|
def get_dir_name(self, tag):
|
|
237
252
|
# guarantee file uniqueness
|
|
238
253
|
time.sleep(1)
|
|
239
|
-
|
|
254
|
+
# 时间格式:年-月-日-时-分-秒-毫秒(精确到千分之一秒)
|
|
255
|
+
time_now = time.strftime("%Y%m%d%H%M%S%f", time.localtime(time.time()))[:-3] # 取前3位毫秒
|
|
256
|
+
|
|
240
257
|
if tag is None or not isinstance(tag, str):
|
|
241
258
|
logger.warning('There is not tag or the type of tag is not string.')
|
|
259
|
+
# 目录名格式:msprobe_rank{设备ID}_{毫秒时间戳}
|
|
242
260
|
dir_name = f'msprobe_rank{self.device_id}_{time_now}'
|
|
243
261
|
else:
|
|
244
262
|
dir_name = f'msprobe_{tag}_rank{self.device_id}_{time_now}'
|