mindstudio-probe 8.2.1__py3-none-any.whl → 8.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/RECORD +39 -40
  3. msprobe/README.md +7 -2
  4. msprobe/core/common/const.py +17 -3
  5. msprobe/core/common/file_utils.py +138 -32
  6. msprobe/core/common/framework_adapter.py +16 -6
  7. msprobe/core/common/utils.py +17 -0
  8. msprobe/core/compare/diff_analyze/first_diff_analyze.py +4 -16
  9. msprobe/core/compare/find_first/utils.py +1 -1
  10. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +6 -1
  11. msprobe/core/hook_manager.py +0 -1
  12. msprobe/docs/01.installation.md +2 -0
  13. msprobe/docs/02.config_introduction.md +1 -1
  14. msprobe/docs/14.data_parse_PyTorch.md +2 -0
  15. msprobe/docs/15.free_benchmarking_PyTorch.md +1 -1
  16. msprobe/docs/21.visualization_PyTorch.md +1 -1
  17. msprobe/docs/26.data_dump_PyTorch_baseline.md +3 -3
  18. msprobe/docs/32.ckpt_compare.md +5 -5
  19. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  20. msprobe/mindspore/compare/utils.py +1 -2
  21. msprobe/mindspore/monitor/module_hook.py +17 -20
  22. msprobe/msprobe.py +6 -4
  23. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  24. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +34 -5
  25. msprobe/pytorch/common/utils.py +2 -52
  26. msprobe/pytorch/compare/utils.py +1 -2
  27. msprobe/pytorch/dump/module_dump/hook_wrapper.py +24 -0
  28. msprobe/pytorch/dump/module_dump/module_processer.py +27 -6
  29. msprobe/pytorch/hook_module/api_register.py +11 -2
  30. msprobe/pytorch/monitor/module_hook.py +16 -34
  31. msprobe/pytorch/pt_config.py +6 -0
  32. msprobe/visualization/builder/graph_builder.py +3 -2
  33. msprobe/visualization/builder/graph_merger.py +13 -0
  34. msprobe/visualization/graph/graph.py +13 -9
  35. msprobe/visualization/utils.py +11 -1
  36. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +0 -3
  37. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/LICENSE +0 -0
  38. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/WHEEL +0 -0
  39. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/entry_points.txt +0 -0
  40. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/top_level.txt +0 -0
@@ -13,21 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import threading
 import sys
+import threading
 from collections import OrderedDict
 
 import torch
 from torch.utils.hooks import BackwardHook, RemovableHandle
 
 from msprobe.core.common.const import Const
+from msprobe.core.common.megatron_utils import wrap_megatron_step, get_micro_step, is_megatron
 from msprobe.core.common.runtime import Runtime
 from msprobe.core.common.utils import ModuleQueue, ThreadSafe
-from msprobe.core.common.megatron_utils import wrap_megatron_step, get_micro_step, is_megatron
 from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook
-from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_output_hook
+from msprobe.pytorch.dump.module_dump.hook_wrapper import (
+    wrap_setup_input_output_hook,
+    wrap_backward_hook_function_apply
+)
+
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 torch_version_above_or_equal_21 = torch.__version__.split('+')[0] >= '2.1'
@@ -59,10 +63,13 @@ def wrap_forward_with_hook_safety(module):
        except _StopRecomputationError as e:
            exception_output = None
            if len(module._forward_hooks.values()) > 0:
-                # msprobe's forward_hook is always the first entry, so only run msprobe's forward_hook
-                hook_fn = list(module._forward_hooks.values())[0]
-                hook_fn(module, args, kwargs, exception_output)
+                # Only run msprobe's forward_hook; its name always contains 'ModuleProcesser.'
+                for hook_fn in module._forward_hooks.values():
+                    if 'ModuleProcesser' in str(hook_fn):
+                        hook_fn(module, args, kwargs, exception_output)
+                        break
            raise e
+
    if torch_version_above_or_equal_21:
        module.forward = wrapped_forward
 
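A note on the selection logic in the hunk above: the repr of a plain or static method includes its qualified name, so a hook defined on a class named ModuleProcesser can be picked out of `module._forward_hooks` by a substring match. A minimal standalone sketch under that assumption (DemoProcesser and other_hook are illustrative names, not msprobe APIs; `with_kwargs=True` needs torch >= 2.0):

```python
import torch
import torch.nn as nn


class DemoProcesser:
    # Stand-in for the hook owner: str(DemoProcesser.forward_hook) contains
    # "DemoProcesser.forward_hook", which is what the substring match keys on.
    @staticmethod
    def forward_hook(module, args, kwargs, output):
        print(f"selected hook fired for {module.__class__.__name__}")


def other_hook(module, args, kwargs, output):
    print("this hook is skipped")


layer = nn.Linear(4, 4)
layer.register_forward_hook(other_hook, with_kwargs=True)
layer.register_forward_hook(DemoProcesser.forward_hook, with_kwargs=True)

# Same pattern as the hunk above: run only the hook whose repr matches the
# owning class name, then stop.
args, kwargs, output = (torch.randn(1, 4),), {}, None
for hook_fn in layer._forward_hooks.values():
    if 'DemoProcesser' in str(hook_fn):
        hook_fn(layer, args, kwargs, output)
        break
```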
@@ -80,6 +87,7 @@ class ModuleProcesser:
     def __init__(self, scope):
         self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
         wrap_setup_input_output_hook()
+        wrap_backward_hook_function_apply()
         try:
             from megatron.core.pipeline_parallel import schedules
             origin_func_id = id(schedules.deallocate_output_tensor)
@@ -146,7 +154,13 @@ class ModuleProcesser:
         modules_and_names_with_index = self.get_modules_and_names(models, recursive, module_names)
         for index, modules_and_names in modules_and_names_with_index.items():
             model = models if index == "-1" else models[int(index)]
+
+            model_list = []
             for name, module in modules_and_names:
+                model_list.append((name, module))
+
+            is_verl = "verl" in sys.modules
+            for idx, (name, module) in enumerate(model_list):
                 if recursive and module == model:
                     continue
                 if not is_torch_nn_module(module):
@@ -157,6 +171,13 @@ class ModuleProcesser:
                     continue
                 if module.__class__.__name__ == "FullyShardedDataParallel":
                     continue
+
+                # In the verl scenario, skip the first and last layers
+                if is_verl and (idx == 1 or idx == len(model_list) - 1):
+                    logger.warning(f"The module {name} is the first or last layer in verl scenario, "
+                                   f"the data dump for this module will be skipped.")
+                    continue
+
                 setattr(module, 'msprobe_hook', True)
                 module_index = (index + Const.SEP) if index != "-1" else ""
                 prefix_name = f'{BaseScope.Module_Type_Module}{Const.SEP}{module_index}{name}{Const.SEP}' + \
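A side note on the `"verl" in sys.modules` check introduced above: a package appears in `sys.modules` only after something in the current process has imported it, so this detects "running under verl" without importing, or even requiring, the package. A tiny illustration:

```python
import sys

# True only if some earlier code in this process already did `import verl`;
# merely having verl installed is not enough.
is_verl = "verl" in sys.modules
print(is_verl)

# For contrast, a module that is certainly loaded here:
print("sys" in sys.modules)  # True
```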
@@ -43,7 +43,6 @@ else:
 
 torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
 
-_inner_used_api = {}
 _supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
 _cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
 dist_data_collect_func = {}
@@ -85,6 +84,12 @@ if not is_gpu:
     mindspeed_op_file_list = [op.split(Const.SEP)[0] + Const.PY_SUFFIX for op in mindspeed_op_list]
     dynamic_import_op(mindspeed.ops, mindspeed_op_file_list)
 
+_inner_used_api = {
+    Const.PT_FRAMEWORK + Const.SEP + Const.PT_API_TYPE_TENSOR: (
+        torch.Tensor, "view_as"
+    )
+}
+
 
 @parameter_adapter
 def tensor_module_forward(module, *args, **kwargs):
@@ -130,13 +135,17 @@ def redirect_wait():
            store_func = dist_data_collect_func.pop(args[0])
            store_func()
            return
+        remove_value = None
        for value in dist_batch_data_collect_func:
            if args[0] in value[0]:
                value[0].remove(args[0])
                if len(value[0]) == 0:
                    store_func = value[1]
                    store_func()
-                return
+                    remove_value = value
+                break
+        if remove_value:
+            dist_batch_data_collect_func.remove(remove_value)
 
    return wrapped_wait
 
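The reworked wait logic above remembers the matching entry and removes it from `dist_batch_data_collect_func` only after the loop ends, instead of returning mid-iteration. A minimal sketch of the same pattern with illustrative names (not msprobe APIs):

```python
# Each entry pairs a set of outstanding request ids with a callback to run
# once every request in the batch has completed.
pending = [
    [{"req_a", "req_b"}, lambda: print("batch 1 done")],
    [{"req_c"}, lambda: print("batch 2 done")],
]


def complete(request_id):
    finished_entry = None
    for entry in pending:
        if request_id in entry[0]:
            entry[0].discard(request_id)
            if not entry[0]:
                entry[1]()              # whole batch is done, fire the callback
                finished_entry = entry
            break
    if finished_entry:
        # Removing outside the loop avoids mutating the list while iterating it.
        pending.remove(finished_entry)


complete("req_c")   # prints "batch 2 done" and drops the second entry
complete("req_a")   # batch 1 is still waiting on req_b, nothing printed
```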
@@ -48,12 +48,10 @@ from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_write
 from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
 from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
 
-
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if not torch_version_above_or_equal_2:
     raise ValueError("monitor require torch>=2.0")
 
-
 FORMAT_MAPPING = {
     MonitorConst.TENSORBOARD: SummaryWriterWithAD,
     MonitorConst.CSV: CSVWriterWithAD,
@@ -150,15 +148,11 @@ class GradContext:
     def __init__(self) -> None:
         self.pre = {}
         self.post = {}
-        self.acc_metric = {}
-        self.acc = {}
         self.actv = {}
 
     def reset(self):
         self.pre.clear()
         self.post.clear()
-        self.acc_metric.clear()
-        self.acc.clear()
         self.actv.clear()
 
 
@@ -510,18 +504,8 @@ class TrainerMon:
         if not self.wg_distribution:
             return {}, {}
 
-        if self.weight_hooked:
-            get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
-
         get_metrics(self.ops, post_grad_dict, self.eps, self.grad_context.post)
-        reduced_grad = self.grad_context.post
-
-        if self.weight_hooked:
-            unreduced_grad = self.grad_context.acc_metric
-        else:
-            unreduced_grad = self.grad_context.pre
-
-        return reduced_grad, unreduced_grad
+        return self.grad_context.post, self.grad_context.pre
 
     def generate_xy_metrics(self):
         actv = {}
@@ -529,7 +513,6 @@ class TrainerMon:
             actv.update(fwd_context.actv)
 
         actv_grad = self.grad_context.actv
-
         return actv, actv_grad
 
     def reload_xy(self, xy_distribution=False):
@@ -607,11 +590,8 @@ class TrainerMon:
         if not self.wg_distribution:
             return
 
-        if self.weight_hooked:
-            self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced',
-                                              use_micro_step=self.monitor_mbs_grad)
-        else:
-            self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
+        self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced',
+                                          use_micro_step=self.monitor_mbs_grad)
         self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
 
     def hook_optimizer(self, optimizer):
@@ -732,9 +712,9 @@ class TrainerMon:
        # Static monitoring can already save at step 0; dynamic monitoring cannot, because it is designed to start on the step after a reset, so self.monitoring is still False at step 0
        if self.monitoring:
            module_rank_valid = not self.module_rank_list or (
-                dist.is_initialized() and dist.get_rank() in self.module_rank_list)
+                dist.is_initialized() and dist.get_rank() in self.module_rank_list)
            step_condition = (context.step >= self.start_step and (
-                context.step - self.start_step) % self.step_interval == 0)
+                context.step - self.start_step) % self.step_interval == 0)
            if module_rank_valid and step_condition:
                self.has_collect_times += 1
 
@@ -791,6 +771,7 @@ class TrainerMon:
                hook(optimizer, args, kwargs)
                step_final_hook(optimizer, args, kwargs)
                return out
+
            return wrapper
 
        optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
@@ -1013,11 +994,11 @@ class TrainerMon:
                vpp_stage + module_name,
            ]:
                if pattern in l2_targets:
-                    return pattern
+                    return pattern
        elif hook_name in ["linear_hook"]:
            return vpp_stage + squash_param_name(module_name, self.squash_name)
        return ""
-
+
    def _hook_module(self, target_names, l2_target_names, module: torch.nn.Module, vpp_stage=''):
        if '_modules' not in module.__dict__:
            # nothing to hook
@@ -1151,7 +1132,7 @@ class TrainerMon:
                context.micro_step = 0
                context.step += 1
            return
-
+
        def stack_hook(module, args, kwargs, module_output, name):
            if module not in self.module_fwd_hook_context_by_module:
                self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
@@ -1221,7 +1202,7 @@ class TrainerMon:
        if self.monitor_mbs_grad:
            self._hook_weights()
            return
-
+
        self.optimizer_mon.patch_grad_sync(self)
 
        if self.enable_megatron or self.enable_deepspeed:
@@ -1281,6 +1262,7 @@ class TrainerMon:
            get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
            out = foreach_reduce(fsdp_params, unsharded_grads, *unused)
            return out
+
        return wrapper
 
    logger.info("Patch fsdp2 foreach_reduce, collect pre_grad metrics.")
@@ -1294,10 +1276,9 @@ class TrainerMon:
        """
        Walk each parameter's gradient-accumulation function (grad_acc) and attach a hook, so that once all of the parameter's gradients have been computed, the unreduced (pre-communication-aggregation) gradient data can be collected.
        """
-        context = self.grad_context
 
        @torch.no_grad
-        def param_hook(*args, context_dict, param, name):
+        def param_hook(*args, param, name):
            key = name
            if self.monitor_mbs_grad:
                key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
@@ -1305,14 +1286,15 @@ class TrainerMon:
            key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
            self.register_param_call_id("param_hook", key)
            param.micro_step += 1
-
+            grad_dict = {}
            if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
                if self.params_have_main_grad:
                    grad = param.main_grad
                else:
                    grad = param.grad
-                context_dict[key] = grad.clone()
+                grad_dict[key] = grad.clone()
 
+            get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
            if param.micro_step == self.micro_batch_number:
                param.micro_step = 0
 
@@ -1322,7 +1304,7 @@ class TrainerMon:
            param_tmp = param.expand_as(param)
            grad_acc = param_tmp.grad_fn.next_functions[0][0]
            handle = grad_acc.register_hook(
-                partial(param_hook, context_dict=context.acc, param=param, name=name))
+                partial(param_hook, param=param, name=name))
            self.grad_accs.append(grad_acc)
            self.handles['wgrads'].append(handle)
 
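For context on the hunk above: `param.expand_as(param).grad_fn.next_functions[0][0]` retrieves the parameter's AccumulateGrad autograd node, and a post-hook on that node fires once the gradient for the parameter has been accumulated in a backward pass. A self-contained sketch of that pattern in plain PyTorch (none of the msprobe bookkeeping, names are illustrative):

```python
import torch

model = torch.nn.Linear(4, 2)
grad_accs, handles = [], []   # keep references so the nodes and hooks stay alive


def make_param_hook(param, name):
    @torch.no_grad()
    def param_hook(*unused):
        # By the time this fires, the gradient for `param` has been accumulated
        # for the current backward pass, so param.grad can be inspected or copied.
        print(name, float(param.grad.norm()))
    return param_hook


for name, param in model.named_parameters():
    # expand_as yields a non-leaf view whose grad_fn leads back to the
    # AccumulateGrad node of `param`.
    grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]
    grad_accs.append(grad_acc)
    handles.append(grad_acc.register_hook(make_param_hook(param, name)))

loss = model(torch.randn(3, 4)).sum()
loss.backward()   # prints the gradient norms of the weight and the bias
```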
@@ -80,6 +80,7 @@ class FreeBenchmarkCheckConfig(BaseConfig):
         self.handler_type = json_config.get("handler_type", PytorchFreeBenchmarkConst.DEFAULT_HANDLER)
         self.fuzz_level = json_config.get("fuzz_level", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_LEVEL)
         self.fuzz_stage = json_config.get("fuzz_stage", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_STAGE)
+        self.list = json_config.get("list")
         self.if_preheat = json_config.get("if_preheat", False)
         self.preheat_step = json_config.get("preheat_step", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
         self.max_sample = json_config.get("max_sample", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
@@ -146,6 +147,11 @@ class FreeBenchmarkCheckConfig(BaseConfig):
             logger.error_log_with_exp(
                 msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
             )
+        if self.fuzz_stage == Const.BACKWARD and not self.list:
+            raise MsprobeException(
+                MsprobeException.INVALID_PARAM_ERROR,
+                f"When fuzz_stage is set to {Const.BACKWARD}, the parameters list must not be empty."
+            )
 
     def _check_fuzz_level(self):
         if self.fuzz_level not in PytorchFreeBenchmarkConst.FUZZ_LEVEL_LIST:
@@ -74,6 +74,7 @@ class GraphBuilder:
         config.graph_b.data_source = GraphConst.JSON_BENCH_KEY
         config.graph_b.step = config.step
         config.graph_b.rank = config.rank
+        config.graph_b.compare_mode = config.compare_mode
         node_to_db(config.graph_b, filename)
         config_to_db(config, filename)
 
@@ -297,8 +298,8 @@ class GraphBuilder:
         no_recompute_map = GraphBuilder._get_no_recompute_map(graph, id_prefixes)
         if not no_recompute_map:
             return
-        # Deep-copy the map of non-recompute nodes for use in backward mode
-        no_recompute_ids_b = copy.deepcopy(no_recompute_map)
+        # Copy the map of non-recompute nodes for use in backward mode
+        no_recompute_ids_b = {node_id: list(node_list) for node_id, node_list in no_recompute_map.items()}
 
         del_indexes = []
         for node_id, id_prefix in recompute_map.items():
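The replacement of `copy.deepcopy` above is safe because the map's values are flat lists, so rebuilding each list already yields an independent structure at a fraction of deepcopy's cost. A quick illustration with a made-up map:

```python
import copy

no_recompute_map = {"Module.layer.0": ["forward.0", "backward.0"]}

# One level of copying: new dict, new lists, same (immutable) string elements.
per_value_copy = {node_id: list(nodes) for node_id, nodes in no_recompute_map.items()}
full_deepcopy = copy.deepcopy(no_recompute_map)

per_value_copy["Module.layer.0"].append("extra")
print(no_recompute_map["Module.layer.0"])   # ['forward.0', 'backward.0'] -- original untouched
print(full_deepcopy["Module.layer.0"])      # ['forward.0', 'backward.0']
```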
@@ -146,6 +146,7 @@ class BaseGraphMerger:
                                                          GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS,
                                                          id_accumulation=True)
         all_collection_node = main_graph_result.graph.get_node(all_collection_node_id)
+        all_collection_node.upnode = main_graph_result.graph.root
         new_main_root_sub_nodes.append(all_collection_node)
         # Apis_Between_Modules.0 --> Apis_Between_Modules_Rank0.0
         origin_main_node_id = main_node.id
@@ -377,6 +378,12 @@ class PPMerger(BaseGraphMerger):
             logger.info('Unable to get pp groups based on Distributed Api (batch_isend_irecv, send, or isend), '
                         'generate pp groups using parallel param "rank_size", "tp" and "pp".')
             _, pp_groups = self.get_default_groups()
+        elif len(pp_groups[0]) != self.parallel_param.pp:
+            logger.warning(f'Based on Distributed Api (batch_isend_irecv, send, or isend), '
+                           f'the resulting pp groups={pp_groups}, '
+                           f'its length is not equal to the parallel param "pp"({self.parallel_param.pp}) you defined, '
+                           f'generate pp groups using parallel param "rank_size", "tp" and "pp".')
+            _, pp_groups = self.get_default_groups()
         logger.info(f'{self.log_prefix} All pp groups is {pp_groups}.')
         return pp_groups
 
@@ -657,6 +664,12 @@ class TPMerger(BaseGraphMerger):
             logger.info('Unable to get tp groups based on Distributed Api (reduce_scatter or all_reduce), '
                         'generate tp groups using parallel param "rank_size", "tp" and "pp".')
             tp_groups, _ = self.get_default_groups()
+        elif len(tp_groups[0]) != self.parallel_param.tp:
+            logger.warning(f'Based on Distributed Api (reduce_scatter or all_reduce), '
+                           f'the resulting tp groups={tp_groups}, '
+                           f'its length is not equal to the parallel param "tp"({self.parallel_param.tp}) you defined, '
+                           f'generate tp groups using parallel param "rank_size", "tp" and "pp".')
+            tp_groups, _ = self.get_default_groups()
         logger.info(f'{self.log_prefix} All tp groups is {tp_groups}.')
         return tp_groups
 
@@ -126,21 +126,25 @@ class Graph:
 
     def get_sorted_nodes(self):
         """
-        Depth-first traversal of the graph to obtain a sorted list of nodes
+        Depth-first traversal of the graph to obtain a sorted list of nodes; implemented with an explicit stack to avoid exceeding the recursion depth limit
         """
         visited = set()
         order = []
+        stack = [(self.root, False)]
 
-        @recursion_depth_decorator('msprobe.visualization.graph.graph.Graph.get_nodes_order.visit', max_depth=500)
-        def visit(node):
+        while stack:
+            node, processed = stack.pop()
             if node.id in visited:
-                return
-            visited.add(node.id)
-            for sub_node in node.subnodes:
-                visit(sub_node)
-            order.append(node)
+                continue
+            if processed:
+                visited.add(node.id)
+                order.append(node)
+            else:
+                stack.append((node, True))
+                for sub_node in reversed(node.subnodes):
+                    if sub_node.id not in visited:
+                        stack.append((sub_node, False))
 
-        visit(self.root)
         return order
 
     def add_node(self, node_op, node_id, up_node=None, id_accumulation=False):
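For clarity, here is the same stack-based post-order traversal as a standalone function, with a hypothetical minimal Node class standing in for the graph's node type: children are pushed in reverse so they pop in their original order, and a node is emitted only on its second visit, after all of its subnodes.

```python
class Node:
    """Hypothetical stand-in with just the fields the traversal needs."""
    def __init__(self, node_id, subnodes=None):
        self.id = node_id
        self.subnodes = subnodes or []


def get_sorted_nodes(root):
    visited, order = set(), []
    stack = [(root, False)]          # (node, already expanded?)
    while stack:
        node, processed = stack.pop()
        if node.id in visited:
            continue
        if processed:                # second visit: all subnodes are done
            visited.add(node.id)
            order.append(node)
        else:                        # first visit: re-push, then push children
            stack.append((node, True))
            for sub_node in reversed(node.subnodes):
                if sub_node.id not in visited:
                    stack.append((sub_node, False))
    return order


root = Node("root", [Node("a", [Node("a.0")]), Node("b")])
print([n.id for n in get_sorted_nodes(root)])   # ['a.0', 'a', 'b', 'root']
```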
@@ -152,7 +152,8 @@ def load_parallel_param(input_param):
 
 
 def validate_parallel_param(parallel_param, dump_path, log_prefix='[NPU]'):
-    params = [parallel_param.tp, parallel_param.pp, parallel_param.rank_size]
+    pattern = re.compile(r'^[a-z\-]+$')
+    params = [parallel_param.tp, parallel_param.pp, parallel_param.rank_size, parallel_param.vpp]
     ranks = check_and_return_dir_contents(dump_path, Const.RANK)
     if len(ranks) != parallel_param.rank_size:
         logger.error(f'{log_prefix} The parallel param "rank_size" error, '
@@ -161,6 +162,12 @@ def validate_parallel_param(parallel_param, dump_path, log_prefix='[NPU]'):
     if any(x is None for x in params):
         logger.error(f'{log_prefix} The parallel params "tp/pp/rank_size" must not be null!')
         raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
+    if any(isinstance(x, bool) for x in params):
+        logger.error(f'{log_prefix} The parallel params "tp/pp/vpp/rank_size" must not be bool!')
+        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
+    if any(not isinstance(x, int) for x in params):
+        logger.error(f'{log_prefix} The parallel params "tp/pp/vpp/rank_size" must be int!')
+        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
     if any(x <= 0 for x in params):
         logger.error(f'{log_prefix} The parallel params "tp/pp/vpp/rank_size" must be greater than 0!')
         raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
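On the ordering of the two new checks above: in Python, `bool` is a subclass of `int`, so a bare `isinstance(x, int)` test would accept a value like `tp: true` from a config file; rejecting bools first closes that gap. A short demonstration:

```python
params = [2, True, 8]

print(isinstance(True, int))                         # True: bool is a subclass of int
print(any(isinstance(x, bool) for x in params))      # True  -> caught by the bool check
print(any(not isinstance(x, int) for x in params))   # False -> the int check alone would let True through
```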
@@ -185,6 +192,9 @@ def validate_parallel_param(parallel_param, dump_path, log_prefix='[NPU]'):
     if not isinstance(parallel_param.order, str):
         logger.error(f'{log_prefix} The parallel params "order" must be of string type!')
         raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
+    if not pattern.match(parallel_param.order):
+        logger.error(f'{log_prefix} The parallel params "order" must consist only of lowercase letters and "-"!')
+        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
 
 
 class ParallelParam:
@@ -1,3 +0,0 @@
-npu_fusion_attention:
-- 4
-- 5