mindstudio-probe 8.2.1__py3-none-any.whl → 8.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/METADATA +1 -1
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/RECORD +39 -40
- msprobe/README.md +7 -2
- msprobe/core/common/const.py +17 -3
- msprobe/core/common/file_utils.py +138 -32
- msprobe/core/common/framework_adapter.py +16 -6
- msprobe/core/common/utils.py +17 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +4 -16
- msprobe/core/compare/find_first/utils.py +1 -1
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +6 -1
- msprobe/core/hook_manager.py +0 -1
- msprobe/docs/01.installation.md +2 -0
- msprobe/docs/02.config_introduction.md +1 -1
- msprobe/docs/14.data_parse_PyTorch.md +2 -0
- msprobe/docs/15.free_benchmarking_PyTorch.md +1 -1
- msprobe/docs/21.visualization_PyTorch.md +1 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +3 -3
- msprobe/docs/32.ckpt_compare.md +5 -5
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/mindspore/compare/utils.py +1 -2
- msprobe/mindspore/monitor/module_hook.py +17 -20
- msprobe/msprobe.py +6 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +34 -5
- msprobe/pytorch/common/utils.py +2 -52
- msprobe/pytorch/compare/utils.py +1 -2
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +24 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +27 -6
- msprobe/pytorch/hook_module/api_register.py +11 -2
- msprobe/pytorch/monitor/module_hook.py +16 -34
- msprobe/pytorch/pt_config.py +6 -0
- msprobe/visualization/builder/graph_builder.py +3 -2
- msprobe/visualization/builder/graph_merger.py +13 -0
- msprobe/visualization/graph/graph.py +13 -9
- msprobe/visualization/utils.py +11 -1
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +0 -3
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/top_level.txt +0 -0
msprobe/pytorch/dump/module_dump/module_processer.py
CHANGED

```diff
@@ -13,21 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import threading
 import sys
+import threading
 from collections import OrderedDict
 
 import torch
 from torch.utils.hooks import BackwardHook, RemovableHandle
 
 from msprobe.core.common.const import Const
+from msprobe.core.common.megatron_utils import wrap_megatron_step, get_micro_step, is_megatron
 from msprobe.core.common.runtime import Runtime
 from msprobe.core.common.utils import ModuleQueue, ThreadSafe
-from msprobe.core.common.megatron_utils import wrap_megatron_step, get_micro_step, is_megatron
 from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook
-from msprobe.pytorch.dump.module_dump.hook_wrapper import
+from msprobe.pytorch.dump.module_dump.hook_wrapper import (
+    wrap_setup_input_output_hook,
+    wrap_backward_hook_function_apply
+)
+
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 torch_version_above_or_equal_21 = torch.__version__.split('+')[0] >= '2.1'
@@ -59,10 +63,13 @@ def wrap_forward_with_hook_safety(module):
         except _StopRecomputationError as e:
             exception_output = None
             if len(module._forward_hooks.values()) > 0:
-                # msprobe's forward_hook
-                hook_fn
-
+                # Only run msprobe's forward_hook; its name always contains 'ModuleProcesser.'
+                for hook_fn in module._forward_hooks.values():
+                    if 'ModuleProcesser' in str(hook_fn):
+                        hook_fn(module, args, kwargs, exception_output)
+                        break
             raise e
+
     if torch_version_above_or_equal_21:
         module.forward = wrapped_forward
 
@@ -80,6 +87,7 @@ class ModuleProcesser:
     def __init__(self, scope):
        self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
         wrap_setup_input_output_hook()
+        wrap_backward_hook_function_apply()
         try:
             from megatron.core.pipeline_parallel import schedules
             origin_func_id = id(schedules.deallocate_output_tensor)
@@ -146,7 +154,13 @@ class ModuleProcesser:
         modules_and_names_with_index = self.get_modules_and_names(models, recursive, module_names)
         for index, modules_and_names in modules_and_names_with_index.items():
             model = models if index == "-1" else models[int(index)]
+
+            model_list = []
             for name, module in modules_and_names:
+                model_list.append((name, module))
+
+            is_verl = "verl" in sys.modules
+            for idx, (name, module) in enumerate(model_list):
                 if recursive and module == model:
                     continue
                 if not is_torch_nn_module(module):
@@ -157,6 +171,13 @@ class ModuleProcesser:
                     continue
                 if module.__class__.__name__ == "FullyShardedDataParallel":
                     continue
+
+                # Skip the first and last layers in the verl scenario
+                if is_verl and (idx == 1 or idx == len(model_list) - 1):
+                    logger.warning(f"The module {name} is the first or last layer in verl scenario, "
+                                   f"the data dump for this module will be skipped.")
+                    continue
+
                 setattr(module, 'msprobe_hook', True)
                 module_index = (index + Const.SEP) if index != "-1" else ""
                 prefix_name = f'{BaseScope.Module_Type_Module}{Const.SEP}{module_index}{name}{Const.SEP}' + \
```
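Note on the hunk above: when `_StopRecomputationError` is caught, only msprobe's own forward hook is replayed, identified by matching 'ModuleProcesser' against the hook's string representation. A minimal standalone sketch of that filtering idea follows; the `ModuleProcesser` class and the dummy tensors are hypothetical stand-ins, not msprobe's actual implementation.

```python
import torch
import torch.nn as nn


class ModuleProcesser:  # hypothetical stand-in; only the qualified name matters here
    @staticmethod
    def forward_hook(module, args, kwargs, output):
        print(f"msprobe-style hook fired for {module.__class__.__name__}")


def replay_msprobe_hooks(module, args, kwargs, output):
    # str() of a function contains its __qualname__, e.g.
    # "<function ModuleProcesser.forward_hook at 0x...>", which the filter matches.
    for hook_fn in module._forward_hooks.values():
        if 'ModuleProcesser' in str(hook_fn):
            hook_fn(module, args, kwargs, output)
            break


layer = nn.Linear(4, 4)
layer.register_forward_hook(ModuleProcesser.forward_hook, with_kwargs=True)  # torch >= 2.0
replay_msprobe_hooks(layer, (torch.randn(2, 4),), {}, None)
```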
msprobe/pytorch/hook_module/api_register.py
CHANGED

```diff
@@ -43,7 +43,6 @@ else:
 
 torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
 
-_inner_used_api = {}
 _supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
 _cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
 dist_data_collect_func = {}
@@ -85,6 +84,12 @@ if not is_gpu:
         mindspeed_op_file_list = [op.split(Const.SEP)[0] + Const.PY_SUFFIX for op in mindspeed_op_list]
         dynamic_import_op(mindspeed.ops, mindspeed_op_file_list)
 
+_inner_used_api = {
+    Const.PT_FRAMEWORK + Const.SEP + Const.PT_API_TYPE_TENSOR: (
+        torch.Tensor, "view_as"
+    )
+}
+
 
 @parameter_adapter
 def tensor_module_forward(module, *args, **kwargs):
@@ -130,13 +135,17 @@ def redirect_wait():
             store_func = dist_data_collect_func.pop(args[0])
             store_func()
             return
+        remove_value = None
         for value in dist_batch_data_collect_func:
             if args[0] in value[0]:
                 value[0].remove(args[0])
                 if len(value[0]) == 0:
                     store_func = value[1]
                     store_func()
-
+                    remove_value = value
+                break
+        if remove_value:
+            dist_batch_data_collect_func.remove(remove_value)
 
     return wrapped_wait
 
```
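The `redirect_wait` change above stops removing entries from `dist_batch_data_collect_func` while iterating over it; the finished entry is remembered and removed after the loop. A small, hypothetical sketch of the same "mark during iteration, remove afterwards" pattern (the names and data below are illustrative, not msprobe's):

```python
# Each entry pairs a set of pending request ids with a completion callback.
pending = [
    ({"req-1", "req-2"}, lambda: print("group A finished")),
    ({"req-3"}, lambda: print("group B finished")),
]


def complete(request_id):
    finished_entry = None
    for entry in pending:
        ids, callback = entry
        if request_id in ids:
            ids.remove(request_id)
            if not ids:              # last outstanding request in this group
                callback()
                finished_entry = entry
            break
    if finished_entry:               # removing outside the loop keeps the iteration safe
        pending.remove(finished_entry)


complete("req-3")  # prints "group B finished" and drops that entry
```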
msprobe/pytorch/monitor/module_hook.py
CHANGED

```diff
@@ -48,12 +48,10 @@ from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_write
 from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
 from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
 
-
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if not torch_version_above_or_equal_2:
     raise ValueError("monitor require torch>=2.0")
 
-
 FORMAT_MAPPING = {
     MonitorConst.TENSORBOARD: SummaryWriterWithAD,
     MonitorConst.CSV: CSVWriterWithAD,
@@ -150,15 +148,11 @@ class GradContext:
     def __init__(self) -> None:
         self.pre = {}
         self.post = {}
-        self.acc_metric = {}
-        self.acc = {}
         self.actv = {}
 
     def reset(self):
         self.pre.clear()
         self.post.clear()
-        self.acc_metric.clear()
-        self.acc.clear()
         self.actv.clear()
 
 
@@ -510,18 +504,8 @@ class TrainerMon:
         if not self.wg_distribution:
             return {}, {}
 
-        if self.weight_hooked:
-            get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
-
         get_metrics(self.ops, post_grad_dict, self.eps, self.grad_context.post)
-
-
-        if self.weight_hooked:
-            unreduced_grad = self.grad_context.acc_metric
-        else:
-            unreduced_grad = self.grad_context.pre
-
-        return reduced_grad, unreduced_grad
+        return self.grad_context.post, self.grad_context.pre
 
     def generate_xy_metrics(self):
         actv = {}
@@ -529,7 +513,6 @@ class TrainerMon:
             actv.update(fwd_context.actv)
 
         actv_grad = self.grad_context.actv
-
         return actv, actv_grad
 
     def reload_xy(self, xy_distribution=False):
@@ -607,11 +590,8 @@ class TrainerMon:
         if not self.wg_distribution:
             return
 
-
-
-                                              use_micro_step=self.monitor_mbs_grad)
-        else:
-            self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
+        self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced',
+                                          use_micro_step=self.monitor_mbs_grad)
         self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
 
     def hook_optimizer(self, optimizer):
@@ -732,9 +712,9 @@ class TrainerMon:
         # Static monitoring can already save at step 0; dynamic monitoring cannot, because it is designed to start on the step after a reset, so self.monitoring is still False at step 0
         if self.monitoring:
             module_rank_valid = not self.module_rank_list or (
-
+                dist.is_initialized() and dist.get_rank() in self.module_rank_list)
             step_condition = (context.step >= self.start_step and (
-
+                context.step - self.start_step) % self.step_interval == 0)
             if module_rank_valid and step_condition:
                 self.has_collect_times += 1
 
@@ -791,6 +771,7 @@ class TrainerMon:
                 hook(optimizer, args, kwargs)
                 step_final_hook(optimizer, args, kwargs)
                 return out
+
             return wrapper
 
         optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
@@ -1013,11 +994,11 @@ class TrainerMon:
                 vpp_stage + module_name,
             ]:
                 if pattern in l2_targets:
-                    return pattern
+                    return pattern
         elif hook_name in ["linear_hook"]:
             return vpp_stage + squash_param_name(module_name, self.squash_name)
         return ""
-
+
     def _hook_module(self, target_names, l2_target_names, module: torch.nn.Module, vpp_stage=''):
         if '_modules' not in module.__dict__:
             # nothing to hook
@@ -1151,7 +1132,7 @@ class TrainerMon:
             context.micro_step = 0
             context.step += 1
             return
-
+
         def stack_hook(module, args, kwargs, module_output, name):
             if module not in self.module_fwd_hook_context_by_module:
                 self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
@@ -1221,7 +1202,7 @@ class TrainerMon:
         if self.monitor_mbs_grad:
             self._hook_weights()
             return
-
+
         self.optimizer_mon.patch_grad_sync(self)
 
         if self.enable_megatron or self.enable_deepspeed:
@@ -1281,6 +1262,7 @@ class TrainerMon:
                 get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
                 out = foreach_reduce(fsdp_params, unsharded_grads, *unused)
                 return out
+
            return wrapper
 
         logger.info("Patch fsdp2 foreach_reduce, collect pre_grad metrics.")
@@ -1294,10 +1276,9 @@ class TrainerMon:
         """
         Walk each parameter's gradient-accumulation function (grad_acc) and attach a hook, so that once all gradients of the parameter have been computed, the gradients can be collected before communication aggregation.
         """
-        context = self.grad_context
 
         @torch.no_grad
-        def param_hook(*args,
+        def param_hook(*args, param, name):
             key = name
             if self.monitor_mbs_grad:
                 key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
@@ -1305,14 +1286,15 @@ class TrainerMon:
             key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
             self.register_param_call_id("param_hook", key)
             param.micro_step += 1
-
+            grad_dict = {}
             if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
                 if self.params_have_main_grad:
                     grad = param.main_grad
                 else:
                     grad = param.grad
-
+                grad_dict[key] = grad.clone()
 
+            get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
             if param.micro_step == self.micro_batch_number:
                 param.micro_step = 0
 
@@ -1322,7 +1304,7 @@ class TrainerMon:
             param_tmp = param.expand_as(param)
             grad_acc = param_tmp.grad_fn.next_functions[0][0]
             handle = grad_acc.register_hook(
-                partial(param_hook,
+                partial(param_hook, param=param, name=name))
             self.grad_accs.append(grad_acc)
             self.handles['wgrads'].append(handle)
 
```
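The `_hook_weights` hunks above bind `param` and `name` into `param_hook` with `functools.partial` and register it on the parameter's gradient-accumulation node. Below is a self-contained sketch of that mechanism, outside of TrainerMon; the `expand_as(...).grad_fn.next_functions[0][0]` walk is the same trick the diff uses, while the surrounding code is only illustrative.

```python
from functools import partial

import torch


def param_hook(*unused, param, name):
    # Runs when the autograd engine processes this parameter's AccumulateGrad node.
    if param.grad is not None:
        print(f"{name}: grad norm = {param.grad.norm().item():.4f}")


weight = torch.nn.Parameter(torch.randn(3, 3))
param_tmp = weight.expand_as(weight)               # gives us a grad_fn to walk from
grad_acc = param_tmp.grad_fn.next_functions[0][0]  # the parameter's AccumulateGrad node
handle = grad_acc.register_hook(partial(param_hook, param=weight, name="weight"))

(weight * 2).sum().backward()                      # hook fires once the grad is accumulated
handle.remove()
```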
msprobe/pytorch/pt_config.py
CHANGED

```diff
@@ -80,6 +80,7 @@ class FreeBenchmarkCheckConfig(BaseConfig):
         self.handler_type = json_config.get("handler_type", PytorchFreeBenchmarkConst.DEFAULT_HANDLER)
         self.fuzz_level = json_config.get("fuzz_level", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_LEVEL)
         self.fuzz_stage = json_config.get("fuzz_stage", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_STAGE)
+        self.list = json_config.get("list")
         self.if_preheat = json_config.get("if_preheat", False)
         self.preheat_step = json_config.get("preheat_step", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
         self.max_sample = json_config.get("max_sample", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
@@ -146,6 +147,11 @@ class FreeBenchmarkCheckConfig(BaseConfig):
             logger.error_log_with_exp(
                 msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
             )
+        if self.fuzz_stage == Const.BACKWARD and not self.list:
+            raise MsprobeException(
+                MsprobeException.INVALID_PARAM_ERROR,
+                f"When fuzz_stage is set to {Const.BACKWARD}, the parameters list must not be empty."
+            )
 
     def _check_fuzz_level(self):
         if self.fuzz_level not in PytorchFreeBenchmarkConst.FUZZ_LEVEL_LIST:
```
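The added check ties `fuzz_stage == backward` to a non-empty `list`. A standalone sketch of the rule follows; the plain dict and `ValueError` are illustrative only, whereas msprobe raises `MsprobeException` as shown in the hunk.

```python
def check_backward_needs_list(json_config):
    # Mirrors the added validation: backward fuzzing requires an explicit api list.
    fuzz_stage = json_config.get("fuzz_stage")
    api_list = json_config.get("list")
    if fuzz_stage == "backward" and not api_list:
        raise ValueError('When fuzz_stage is "backward", "list" must not be empty.')


check_backward_needs_list({"fuzz_stage": "backward", "list": ["torch.matmul"]})  # passes
# check_backward_needs_list({"fuzz_stage": "backward"})                          # would raise
```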
msprobe/visualization/builder/graph_builder.py
CHANGED

```diff
@@ -74,6 +74,7 @@ class GraphBuilder:
         config.graph_b.data_source = GraphConst.JSON_BENCH_KEY
         config.graph_b.step = config.step
         config.graph_b.rank = config.rank
+        config.graph_b.compare_mode = config.compare_mode
         node_to_db(config.graph_b, filename)
         config_to_db(config, filename)
 
@@ -297,8 +298,8 @@ class GraphBuilder:
         no_recompute_map = GraphBuilder._get_no_recompute_map(graph, id_prefixes)
         if not no_recompute_map:
             return
-        #
-        no_recompute_ids_b =
+        # Copy the dict of non-recompute nodes for backward mode
+        no_recompute_ids_b = {node_id: list(node_list) for node_id, node_list in no_recompute_map.items()}
 
         del_indexes = []
         for node_id, id_prefix in recompute_map.items():
```
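The rewritten copy builds `no_recompute_ids_b` with a fresh `list()` per value instead of sharing the inner lists. The snippet below (with made-up data) shows why a plain `dict()` copy would not be enough.

```python
original = {"Module.layer1": ["api_a", "api_b"]}

shared = dict(original)                                  # shallow copy: inner list is shared
independent = {k: list(v) for k, v in original.items()}  # per-value copy, as in the diff

shared["Module.layer1"].append("api_c")
assert original["Module.layer1"] == ["api_a", "api_b", "api_c"]  # mutation leaked back
assert independent["Module.layer1"] == ["api_a", "api_b"]        # stays isolated
```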
msprobe/visualization/builder/graph_merger.py
CHANGED

```diff
@@ -146,6 +146,7 @@ class BaseGraphMerger:
                 GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS,
                 id_accumulation=True)
         all_collection_node = main_graph_result.graph.get_node(all_collection_node_id)
+        all_collection_node.upnode = main_graph_result.graph.root
         new_main_root_sub_nodes.append(all_collection_node)
         # Apis_Between_Modules.0 --> Apis_Between_Modules_Rank0.0
         origin_main_node_id = main_node.id
@@ -377,6 +378,12 @@ class PPMerger(BaseGraphMerger):
             logger.info('Unable to get pp groups based on Distributed Api (batch_isend_irecv, send, or isend), '
                         'generate pp groups using parallel param "rank_size", "tp" and "pp".')
             _, pp_groups = self.get_default_groups()
+        elif len(pp_groups[0]) != self.parallel_param.pp:
+            logger.warning(f'Based on Distributed Api (atch_isend_irecv, send, or isend), '
+                           f'the resulting pp groups={pp_groups}, '
+                           f'its length is not equal to the parallel param "pp"({self.parallel_param.pp}) you defined, '
+                           f'generate pp groups using parallel param "rank_size", "tp" and "pp".')
+            _, pp_groups = self.get_default_groups()
         logger.info(f'{self.log_prefix} All pp groups is {pp_groups}.')
         return pp_groups
 
@@ -657,6 +664,12 @@ class TPMerger(BaseGraphMerger):
             logger.info('Unable to get tp groups based on Distributed Api (reduce_scatter or all_reduce), '
                         'generate tp groups using parallel param "rank_size", "tp" and "pp".')
             tp_groups, _ = self.get_default_groups()
+        elif len(tp_groups[0]) != self.parallel_param.tp:
+            logger.warning(f'Based on Distributed Api (reduce_scatter or all_reduce), '
+                           f'the resulting tp groups={tp_groups}, '
+                           f'its length is not equal to the parallel param "tp"({self.parallel_param.tp}) you defined, '
+                           f'generate tp groups using parallel param "rank_size", "tp" and "pp".')
+            tp_groups, _ = self.get_default_groups()
         logger.info(f'{self.log_prefix} All tp groups is {tp_groups}.')
         return tp_groups
 
```
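Both mergers now fall back to `get_default_groups()` when the group size recovered from the Distributed APIs does not match the user-supplied `pp` or `tp`. `get_default_groups` itself is not shown in this diff; the sketch below only illustrates one plausible way such default groups could be derived from `rank_size`, `tp` and `pp`, so that the first group's length matches the checks above.

```python
def default_groups(rank_size, tp, pp):
    # Hypothetical derivation: contiguous tp groups, strided pp groups.
    ranks = list(range(rank_size))
    tp_groups = [ranks[i:i + tp] for i in range(0, rank_size, tp)]
    stride = rank_size // pp
    pp_groups = [ranks[i::stride] for i in range(stride)]
    return tp_groups, pp_groups


tp_groups, pp_groups = default_groups(rank_size=8, tp=2, pp=2)
print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]] -> len(tp_groups[0]) == tp
print(pp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]] -> len(pp_groups[0]) == pp
```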
msprobe/visualization/graph/graph.py
CHANGED

```diff
@@ -126,21 +126,25 @@ class Graph:
 
     def get_sorted_nodes(self):
         """
-        Traverse the graph depth-first to obtain the sorted nodes
+        Traverse the graph depth-first to obtain the sorted node list; implemented with an explicit stack to avoid exceeding the recursion depth
         """
         visited = set()
         order = []
+        stack = [(self.root, False)]
 
-
-
+        while stack:
+            node, processed = stack.pop()
             if node.id in visited:
-
-
-
-
-
+                continue
+            if processed:
+                visited.add(node.id)
+                order.append(node)
+            else:
+                stack.append((node, True))
+                for sub_node in reversed(node.subnodes):
+                    if sub_node.id not in visited:
+                        stack.append((sub_node, False))
 
-        visit(self.root)
         return order
 
     def add_node(self, node_op, node_id, up_node=None, id_accumulation=False):
```
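`get_sorted_nodes` now uses an explicit stack instead of a recursive helper, so very deep graphs no longer hit Python's recursion limit. A self-contained version of the same post-order traversal is sketched below; the `Node` class is a minimal hypothetical stand-in for msprobe's graph nodes.

```python
class Node:
    def __init__(self, node_id, subnodes=None):
        self.id = node_id
        self.subnodes = subnodes or []


def sorted_nodes(root):
    visited, order, stack = set(), [], [(root, False)]
    while stack:
        node, processed = stack.pop()
        if node.id in visited:
            continue
        if processed:                            # children already handled: emit the node
            visited.add(node.id)
            order.append(node)
        else:
            stack.append((node, True))           # revisit after the children
            for sub in reversed(node.subnodes):  # reversed keeps left-to-right order
                if sub.id not in visited:
                    stack.append((sub, False))
    return [n.id for n in order]


root = Node("root", [Node("a", [Node("a.0")]), Node("b")])
print(sorted_nodes(root))  # ['a.0', 'a', 'b', 'root']
```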
msprobe/visualization/utils.py
CHANGED

```diff
@@ -152,7 +152,8 @@ def load_parallel_param(input_param):
 
 
 def validate_parallel_param(parallel_param, dump_path, log_prefix='[NPU]'):
-
+    pattern = re.compile(r'^[a-z\-]+$')
+    params = [parallel_param.tp, parallel_param.pp, parallel_param.rank_size, parallel_param.vpp]
     ranks = check_and_return_dir_contents(dump_path, Const.RANK)
     if len(ranks) != parallel_param.rank_size:
         logger.error(f'{log_prefix} The parallel param "rank_size" error, '
@@ -161,6 +162,12 @@ def validate_parallel_param(parallel_param, dump_path, log_prefix='[NPU]'):
     if any(x is None for x in params):
         logger.error(f'{log_prefix} The parallel params "tp/pp/rank_size" must not be null!')
         raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
+    if any(isinstance(x, bool) for x in params):
+        logger.error(f'{log_prefix} The parallel params "tp/pp/vpp/rank_size" must not be bool!')
+        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
+    if any(not isinstance(x, int) for x in params):
+        logger.error(f'{log_prefix} The parallel params "tp/pp/vpp/rank_size" must be int!')
+        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
     if any(x <= 0 for x in params):
         logger.error(f'{log_prefix} The parallel params "tp/pp/vpp/rank_size" must be greater than 0!')
         raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
@@ -185,6 +192,9 @@ def validate_parallel_param(parallel_param, dump_path, log_prefix='[NPU]'):
     if not isinstance(parallel_param.order, str):
         logger.error(f'{log_prefix} The parallel params "order" must be of string type!')
         raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
+    if not pattern.match(parallel_param.order):
+        logger.error(f'{log_prefix} The parallel params "order" must consist only of lowercase letters and "-"!')
+        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
 
 
 class ParallelParam:
```
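The new type checks reject booleans before checking for `int`. That ordering matters because `bool` is a subclass of `int` in Python, so a bare `isinstance(x, int)` test would happily accept `True`:

```python
params = {"tp": 2, "pp": True, "rank_size": 8, "vpp": 1}

assert isinstance(True, int)  # the pitfall: bool passes the int check

bools = [k for k, v in params.items() if isinstance(v, bool)]
non_ints = [k for k, v in params.items() if not isinstance(v, int)]
print(bools, non_ints)  # ['pp'] [] -- only the explicit bool check catches the bad value
```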