mindstudio-probe 8.2.1__py3-none-any.whl → 8.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/RECORD +46 -37
  3. msprobe/README.md +3 -1
  4. msprobe/core/common/file_utils.py +80 -25
  5. msprobe/core/common/framework_adapter.py +7 -6
  6. msprobe/core/compare/diff_analyze/first_diff_analyze.py +4 -16
  7. msprobe/core/compare/find_first/utils.py +1 -1
  8. msprobe/core/hook_manager.py +16 -3
  9. msprobe/core/service.py +16 -5
  10. msprobe/docs/02.config_introduction.md +14 -1
  11. msprobe/docs/08.accuracy_checker_online_PyTorch.md +295 -0
  12. msprobe/docs/15.free_benchmarking_PyTorch.md +1 -1
  13. msprobe/docs/25.tool_function_introduction.md +1 -0
  14. msprobe/docs/26.data_dump_PyTorch_baseline.md +3 -3
  15. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  16. msprobe/mindspore/compare/utils.py +1 -2
  17. msprobe/msprobe.py +6 -4
  18. msprobe/pytorch/api_accuracy_checker/common/config.py +36 -3
  19. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +24 -0
  20. msprobe/pytorch/api_accuracy_checker/compare/compare.py +12 -2
  21. msprobe/pytorch/api_accuracy_checker/config.yaml +6 -1
  22. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  23. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +132 -12
  24. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  25. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +205 -0
  26. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +378 -0
  27. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +239 -0
  28. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  29. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +250 -0
  30. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  31. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +198 -0
  32. msprobe/pytorch/attl_manager.py +65 -0
  33. msprobe/pytorch/common/utils.py +22 -2
  34. msprobe/pytorch/compare/utils.py +1 -2
  35. msprobe/pytorch/debugger/debugger_config.py +10 -0
  36. msprobe/pytorch/dump/module_dump/hook_wrapper.py +24 -0
  37. msprobe/pytorch/dump/module_dump/module_processer.py +9 -3
  38. msprobe/pytorch/hook_module/api_register.py +6 -1
  39. msprobe/pytorch/pt_config.py +57 -2
  40. msprobe/pytorch/pytorch_service.py +11 -2
  41. msprobe/visualization/builder/graph_builder.py +1 -0
  42. msprobe/visualization/utils.py +11 -1
  43. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +0 -3
  44. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/LICENSE +0 -0
  45. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/WHEEL +0 -0
  46. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/entry_points.txt +0 -0
  47. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/top_level.txt +0 -0
msprobe/pytorch/common/utils.py

@@ -39,7 +39,6 @@ except ImportError:
 else:
     is_gpu = False
 
-
 torch_without_guard_version = torch.__version__ >= '2.1'
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 
@@ -416,7 +415,8 @@ def is_recomputation():
 
     # Identify indices in the call stack where the specific function is being executed
     for idx, frame_info in enumerate(call_stack):
-        if frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward':
+        if (frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward' and
+                "megatron" in frame_info.filename):
            backward_function_indices.append(idx)
 
     # Check if the execution is within 'torch/autograd/function.py' file
@@ -471,3 +471,23 @@ def register_forward_hook(module, forward_hook):
         module.register_forward_hook(forward_hook, with_kwargs=True)
     else:
         module.register_forward_hook(forward_hook)
+
+
+def save_api_data(api_data):
+    """Save data to io stream"""
+    try:
+        io_buff = io.BytesIO()
+        torch.save(api_data, io_buff)
+    except Exception as e:
+        raise RuntimeError(f"save api_data to io_buff failed") from e
+    return io_buff
+
+
+def load_api_data(api_data_bytes):
+    """Load data from bytes stream"""
+    try:
+        buffer = io.BytesIO(api_data_bytes)
+        buffer = torch.load(buffer, map_location="cpu")
+    except Exception as e:
+        raise RuntimeError(f"load api_data from bytes failed") from e
+    return buffer
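
The two new helpers are symmetric: save_api_data serializes any torch-picklable object into an in-memory io.BytesIO via torch.save, and load_api_data rebuilds it from raw bytes with all tensors mapped to CPU. A minimal round trip, assuming the helpers are imported from msprobe.pytorch.common.utils (the payload dict is illustrative):

import torch

from msprobe.pytorch.common.utils import load_api_data, save_api_data

api_data = {"name": "Torch.matmul.0.forward", "args": (torch.randn(2, 3),)}
io_buff = save_api_data(api_data)             # io.BytesIO written by torch.save
restored = load_api_data(io_buff.getvalue())  # bytes -> object, tensors on CPU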
msprobe/pytorch/compare/utils.py

@@ -27,8 +27,7 @@ def read_pt_data(dir_path, file_name):
         return None
 
     data_path = os.path.join(dir_path, file_name)
-    path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
-                               FileCheckConst.PT_SUFFIX, False)
+    path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.PT_SUFFIX)
     data_path = path_checker.common_check()
     try:
         # detach because numpy can not process gradient information
msprobe/pytorch/debugger/debugger_config.py

@@ -48,6 +48,16 @@ class DebuggerConfig:
             "max_sample": task_config.max_sample
         }
 
+        self.online_run_ut = False
+        if self.task == Const.TENSOR:
+            # dump api tensor and collaborate with online run_ut
+            self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
+            self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
+            self.tls_path = task_config.tls_path if task_config.tls_path else ""
+            self.host = task_config.host if task_config.host else ""
+            self.port = task_config.port if task_config.port else -1
+            self.online_run_ut_recompute = task_config.online_run_ut_recompute \
+                if isinstance(task_config.online_run_ut_recompute, bool) else False
 
         self.check()
         self._check_statistics_config(task_config)
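
These DebuggerConfig fields are read from the "tensor" section of config.json (see the TensorConfig change in pt_config.py below). A hypothetical fragment with illustrative values; note that a non-empty nfs_path short-circuits the TLS/host/port checks, so NFS transport takes precedence:

{
    "task": "tensor",
    "tensor": {
        "online_run_ut": true,
        "nfs_path": "",
        "tls_path": "./certs",
        "host": "127.0.0.1",
        "port": 8080,
        "online_run_ut_recompute": false
    }
}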
msprobe/pytorch/dump/module_dump/hook_wrapper.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from functools import wraps
+from typing import Any, Callable
 
 import torch
 from torch.utils.hooks import BackwardHook
@@ -21,6 +22,9 @@ from torch.utils.hooks import BackwardHook
 from msprobe.core.common.const import Const
 from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.hook_module.api_register import get_api_register
+
+torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 
 
 def wrap_setup_backward_hook(func):
@@ -92,3 +96,23 @@ def wrap_setup_backward_hook(func):
 def wrap_setup_input_output_hook():
     BackwardHook.setup_input_hook = wrap_setup_backward_hook(BackwardHook.setup_input_hook)
     BackwardHook.setup_output_hook = wrap_setup_backward_hook(BackwardHook.setup_output_hook)
+
+
+def get_apply_func_wrapper(original_func: Callable) -> Callable:
+    @wraps(original_func)
+    def wrapped_apply(*args, **kwargs) -> Any:
+        api_register = get_api_register()
+        if api_register:
+            api_register.restore_inner_used_api()
+        result = original_func(*args, **kwargs)
+        if api_register:
+            api_register.register_inner_used_api()
+        return result
+
+    return wrapped_apply
+
+
+def wrap_backward_hook_function_apply():
+    if torch_version_above_or_equal_2:
+        original_apply = torch.nn.modules._functions.BackwardHookFunction.apply
+        torch.nn.modules._functions.BackwardHookFunction.apply = get_apply_func_wrapper(original_apply)
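
The wrapper exists so that BackwardHookFunction.apply runs against the native torch APIs instead of msprobe's hooked replacements. A minimal sketch of the same unhook-call-rehook pattern with a stand-in registry object (hypothetical; the real one comes from get_api_register(), and the shipped code returns the result without a try/finally):

from functools import wraps

def guarded(original_func, registry):
    @wraps(original_func)
    def wrapped(*args, **kwargs):
        registry.restore_inner_used_api()       # swap the native APIs back in
        try:
            return original_func(*args, **kwargs)
        finally:
            registry.register_inner_used_api()  # re-install the dump hooks
    return wrapped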
msprobe/pytorch/dump/module_dump/module_processer.py

@@ -13,21 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import threading
 import sys
+import threading
 from collections import OrderedDict
 
 import torch
 from torch.utils.hooks import BackwardHook, RemovableHandle
 
 from msprobe.core.common.const import Const
+from msprobe.core.common.megatron_utils import wrap_megatron_step, get_micro_step, is_megatron
 from msprobe.core.common.runtime import Runtime
 from msprobe.core.common.utils import ModuleQueue, ThreadSafe
-from msprobe.core.common.megatron_utils import wrap_megatron_step, get_micro_step, is_megatron
 from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook
-from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_output_hook
+from msprobe.pytorch.dump.module_dump.hook_wrapper import (
+    wrap_setup_input_output_hook,
+    wrap_backward_hook_function_apply
+)
+
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 torch_version_above_or_equal_21 = torch.__version__.split('+')[0] >= '2.1'
@@ -63,6 +67,7 @@ def wrap_forward_with_hook_safety(module):
                 hook_fn = list(module._forward_hooks.values())[0]
                 hook_fn(module, args, kwargs, exception_output)
             raise e
+
     if torch_version_above_or_equal_21:
         module.forward = wrapped_forward
 
@@ -80,6 +85,7 @@ class ModuleProcesser:
     def __init__(self, scope):
         self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
         wrap_setup_input_output_hook()
+        wrap_backward_hook_function_apply()
         try:
             from megatron.core.pipeline_parallel import schedules
             origin_func_id = id(schedules.deallocate_output_tensor)
msprobe/pytorch/hook_module/api_register.py

@@ -43,7 +43,6 @@ else:
 
 torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
 
-_inner_used_api = {}
 _supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
 _cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
 dist_data_collect_func = {}
@@ -85,6 +84,12 @@ if not is_gpu:
     mindspeed_op_file_list = [op.split(Const.SEP)[0] + Const.PY_SUFFIX for op in mindspeed_op_list]
     dynamic_import_op(mindspeed.ops, mindspeed_op_file_list)
 
+_inner_used_api = {
+    Const.PT_FRAMEWORK + Const.SEP + Const.PT_API_TYPE_TENSOR: (
+        torch.Tensor, "view_as"
+    )
+}
+
 
 @parameter_adapter
 def tensor_module_forward(module, *args, **kwargs):
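
Here the registry key names a framework/API-type bucket and the value pins the owner object and attribute that msprobe itself calls internally (torch.Tensor.view_as). Conceptually, registering and restoring such an inner-used API is plain attribute swapping (a hypothetical illustration; the real ApiRegister logic is not part of this diff):

import torch

_native_view_as = torch.Tensor.view_as     # keep the native method aside

def hooked_view_as(self, other):
    # a dump hook would record the call here before delegating
    return _native_view_as(self, other)

torch.Tensor.view_as = hooked_view_as      # register_inner_used_api()
torch.Tensor.view_as = _native_view_as     # restore_inner_used_api()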
msprobe/pytorch/pt_config.py

@@ -35,15 +35,48 @@ from msprobe.pytorch.hook_module.utils import get_ops
 class TensorConfig(BaseConfig):
     def __init__(self, json_config):
         super().__init__(json_config)
+        self.online_run_ut = json_config.get("online_run_ut", False)
+        self.nfs_path = json_config.get("nfs_path", "")
+        self.host = json_config.get("host", "")
+        self.port = json_config.get("port", -1)
+        self.tls_path = json_config.get("tls_path", "./")
+        self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
         self.check_config()
         self._check_summary_mode()
         self._check_file_format()
-
+        if self.online_run_ut:
+            self._check_online_run_ut()
 
     def _check_file_format(self):
         if self.file_format is not None and self.file_format not in ["npy", "bin"]:
             raise Exception("file_format is invalid")
 
+    def _check_online_run_ut(self):
+        if not isinstance(self.online_run_ut, bool):
+            raise Exception(f"online_run_ut: {self.online_run_ut} is invalid.")
+
+        if not isinstance(self.online_run_ut_recompute, bool):
+            raise Exception(f"online_run_ut_recompute: {self.online_run_ut_recompute} is invalid.")
+
+        if self.nfs_path:
+            check_file_or_directory_path(self.nfs_path, isdir=True)
+            return
+
+        if self.tls_path:
+            check_file_or_directory_path(self.tls_path, isdir=True)
+            check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
+            check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
+            check_file_or_directory_path(os.path.join(self.tls_path, "ca.crt"))
+            crl_path = os.path.join(self.tls_path, "crl.pem")
+            if os.path.exists(crl_path):
+                check_file_or_directory_path(crl_path)
+
+        if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
+            raise Exception(f"host: {self.host} is invalid.")
+
+        if not isinstance(self.port, int) or not (0 < self.port <= 65535):
+            raise Exception(f"port: {self.port} is invalid, port range 0-65535.")
+
 
 class StatisticsConfig(BaseConfig):
     def __init__(self, json_config):
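
When nfs_path is empty and tls_path is set, _check_online_run_ut expects the client-side certificate material in that directory, consulting crl.pem only if it exists. The expected layout, with file names exactly as checked above:

tls_path/
├── client.key   # client private key
├── client.crt   # client certificate
├── ca.crt       # CA certificate for verifying the peer
└── crl.pem      # optional certificate revocation list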
@@ -80,6 +113,7 @@ class FreeBenchmarkCheckConfig(BaseConfig):
         self.handler_type = json_config.get("handler_type", PytorchFreeBenchmarkConst.DEFAULT_HANDLER)
         self.fuzz_level = json_config.get("fuzz_level", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_LEVEL)
         self.fuzz_stage = json_config.get("fuzz_stage", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_STAGE)
+        self.list = json_config.get("list")
         self.if_preheat = json_config.get("if_preheat", False)
         self.preheat_step = json_config.get("preheat_step", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
         self.max_sample = json_config.get("max_sample", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
@@ -146,6 +180,11 @@ class FreeBenchmarkCheckConfig(BaseConfig):
             logger.error_log_with_exp(
                 msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
             )
+        if self.fuzz_stage == Const.BACKWARD and not self.list:
+            raise MsprobeException(
+                MsprobeException.INVALID_PARAM_ERROR,
+                f"When fuzz_stage is set to {Const.BACKWARD}, the parameters list must not be empty."
+            )
 
     def _check_fuzz_level(self):
         if self.fuzz_level not in PytorchFreeBenchmarkConst.FUZZ_LEVEL_LIST:
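
Per the new check, when fuzz_stage is "backward" (Const.BACKWARD) a non-empty list is now mandatory. A hypothetical free_benchmark fragment that passes validation (the API name is illustrative):

{
    "free_benchmark": {
        "fuzz_stage": "backward",
        "list": ["Functional.conv2d.0"]
    }
}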
@@ -218,7 +257,12 @@ class RunUTConfig(BaseConfig):
         self.white_list = json_config.get("white_list", Const.DEFAULT_LIST)
         self.black_list = json_config.get("black_list", Const.DEFAULT_LIST)
         self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH)
-
+        self.is_online = json_config.get("is_online", False)
+        self.nfs_path = json_config.get("nfs_path", "")
+        self.host = json_config.get("host", "")
+        self.port = json_config.get("port", -1)
+        self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
+        self.tls_path = json_config.get("tls_path", "./")
         self.check_run_ut_config()
 
     @classmethod
@@ -236,11 +280,22 @@ class RunUTConfig(BaseConfig):
         if not os.path.exists(error_data_path):
             raise Exception("error_data_path: %s does not exist" % error_data_path)
 
+    @classmethod
+    def check_nfs_path_config(cls, nfs_path):
+        if nfs_path:
+            FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
+
+    @classmethod
+    def check_tls_path_config(cls, tls_path):
+        if tls_path:
+            FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
 
     def check_run_ut_config(self):
         RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
         RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list)
         RunUTConfig.check_error_data_path_config(self.error_data_path)
+        RunUTConfig.check_nfs_path_config(self.nfs_path)
+        RunUTConfig.check_tls_path_config(self.tls_path)
 
 
 class GradToolConfig(BaseConfig):
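
The server-side counterparts land in the run_ut section parsed by RunUTConfig. A hypothetical fragment (key names and defaults taken from the json_config.get calls above; values illustrative):

{
    "run_ut": {
        "white_list": [],
        "black_list": [],
        "error_data_path": "./",
        "is_online": true,
        "nfs_path": "",
        "host": "127.0.0.1",
        "port": 8080,
        "rank_list": [0],
        "tls_path": "./certs"
    }
}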
msprobe/pytorch/pytorch_service.py

@@ -15,8 +15,9 @@
 
 from msprobe.core.common.utils import Const
 from msprobe.core.service import BaseService
+from msprobe.pytorch.attl_manager import ATTLManager
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import get_rank_if_initialized
+from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_above_or_equal_2
 from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
 from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate, redirect_wait
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
@@ -24,6 +25,9 @@ from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager
 from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
 from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, preprocess_func
 
+if torch_version_above_or_equal_2:
+    from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
+
 
 class PytorchService(BaseService):
     @property
@@ -41,10 +45,12 @@ class PytorchService(BaseService):
         self.logger = logger
         self.api_register = get_api_register()
         self.module_processor = ModuleProcesser(self.data_collector.scope)
-        self.hook_manager = PytorchHookManager(self.data_collector, self.config)
+        self.attl_manager = ATTLManager(self.config)
+        self.hook_manager = PytorchHookManager(self.data_collector, self.config, self.attl_manager)
         self.api_template = ApiTemplate
 
     def _register_hook(self):
+        self.attl_manager.attl_init()
         if self._is_mix_level:
             register_optimizer_hook(self.data_collector)
 
@@ -59,6 +65,9 @@ class PytorchService(BaseService):
         self.module_processor.register_module_hook(self.model, self.build_hook)
         self.logger.info(f"The module {self.config.task} hook function is successfully mounted to the model.")
 
+    def _run_ut_dispatch(self, status):
+        if torch_version_above_or_equal_2:
+            run_ut_dispatch(self.attl_manager.attl, status, self.config.online_run_ut_recompute)
 
     def _reset_status(self):
         super()._reset_status()
msprobe/visualization/builder/graph_builder.py

@@ -74,6 +74,7 @@ class GraphBuilder:
         config.graph_b.data_source = GraphConst.JSON_BENCH_KEY
         config.graph_b.step = config.step
         config.graph_b.rank = config.rank
+        config.graph_b.compare_mode = config.compare_mode
         node_to_db(config.graph_b, filename)
         config_to_db(config, filename)
 
msprobe/visualization/utils.py

@@ -152,7 +152,8 @@ def load_parallel_param(input_param):
 
 
 def validate_parallel_param(parallel_param, dump_path, log_prefix='[NPU]'):
-    params = [parallel_param.tp, parallel_param.pp, parallel_param.rank_size]
+    pattern = re.compile(r'^[a-z\-]+$')
+    params = [parallel_param.tp, parallel_param.pp, parallel_param.rank_size, parallel_param.vpp]
     ranks = check_and_return_dir_contents(dump_path, Const.RANK)
     if len(ranks) != parallel_param.rank_size:
         logger.error(f'{log_prefix} The parallel param "rank_size" error, '
@@ -161,6 +162,12 @@ def validate_parallel_param(parallel_param, dump_path, log_prefix='[NPU]'):
     if any(x is None for x in params):
         logger.error(f'{log_prefix} The parallel params "tp/pp/rank_size" must not be null!')
         raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
+    if any(isinstance(x, bool) for x in params):
+        logger.error(f'{log_prefix} The parallel params "tp/pp/vpp/rank_size" must not be bool!')
+        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
+    if any(not isinstance(x, int) for x in params):
+        logger.error(f'{log_prefix} The parallel params "tp/pp/vpp/rank_size" must be int!')
+        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
     if any(x <= 0 for x in params):
         logger.error(f'{log_prefix} The parallel params "tp/pp/vpp/rank_size" must be greater than 0!')
         raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
@@ -185,6 +192,9 @@ def validate_parallel_param(parallel_param, dump_path, log_prefix='[NPU]'):
     if not isinstance(parallel_param.order, str):
         logger.error(f'{log_prefix} The parallel params "order" must be of string type!')
         raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
+    if not pattern.match(parallel_param.order):
+        logger.error(f'{log_prefix} The parallel params "order" must consist only of lowercase letters and "-"!')
+        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
 
 
 class ParallelParam:
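
The new pattern restricts order to lowercase letters and hyphens, the shape of Megatron-style parallel-order strings. A quick sanity check (sample values are illustrative):

import re

pattern = re.compile(r'^[a-z\-]+$')
print(bool(pattern.match("tp-pp-dp")))   # True: lowercase letters and '-'
print(bool(pattern.match("tp_pp_dp")))   # False: underscores are rejected
print(bool(pattern.match("TP-PP")))      # False: uppercase is rejected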
msprobe/core/compare/diff_analyze/ignore_op_list.yaml

@@ -1,3 +0,0 @@
-npu_fusion_attention:
-  - 4
-  - 5