mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
Files changed (226)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py (new file)
@@ -0,0 +1,44 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import numpy as np
+
+
+def softmax_func(x, axis=None):
+    x = x.float()
+    x_max = x.max(dim=axis, keepdims=True).values
+    x_sub = x - x_max
+    y = torch.exp(x_sub)
+    x_sum = y.sum(dim=axis, keepdims=True)
+    ans = 0 if (x_sum == 0).any() else y / x_sum
+    return ans
+
+
+def npu_moe_gating_top_k_softmax(x, finished_optional, k):
+    input_dtype = x.dtype
+    num_expert = x.shape[-1]
+    softmax = softmax_func(x, -1)
+    softmax = softmax.to(input_dtype)
+    expert_idx = torch.argsort(-softmax, dim=-1, stable=True)
+    expert_idx = expert_idx[:, :k]
+    y = torch.gather(softmax, index=expert_idx, dim=-1)
+    if finished_optional is not None:
+        finished_optional = finished_optional.view(finished_optional.shape[0], 1)
+        finished_optional = finished_optional.expand(-1, k)
+        expert_idx = torch.where(finished_optional, num_expert, expert_idx)
+    row_idx = torch.arange(y.shape[0] * y.shape[1]).reshape(y.shape[1], y.shape[0]).t()
+
+    return y, expert_idx, row_idx
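This new file is the CPU "golden" reference for the Ascend npu_moe_gating_top_k_softmax operator. A quick smoke test with the functions above in scope (a minimal sketch, not taken from the package's tests; shapes are illustrative, and the scalar torch.where overload needs a reasonably recent PyTorch):

import torch

x = torch.randn(4, 8)  # 4 tokens routed over 8 experts
finished = torch.tensor([False, False, True, False])  # token 2 is already finished

y, expert_idx, row_idx = npu_moe_gating_top_k_softmax(x, finished, k=2)
# y: top-2 gating weights; expert_idx: chosen experts, with finished rows
# redirected to num_expert; row_idx: the row layout the NPU op also returns
print(y.shape, expert_idx.shape, row_idx.shape)  # three torch.Size([4, 2])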
msprobe/pytorch/bench_functions/npu_fusion_attention.py
@@ -30,6 +30,7 @@
 numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
 """

+from collections import namedtuple
 import torch
 import numpy as np
 from einops import rearrange
@@ -54,6 +55,14 @@ GTYPE = torch.float64  # float64 is required on ARM hosts; float32 is enough on x86
 SOFTMAX_BUILD_MODE = "QKV"  # "MAX_SUM"


+FaForwardParams = namedtuple("FaForwardParams",
+                             ["q", "k", "v", "drop_mask", "atten_mask", "pse", "scale", "keep_prob"])
+FaBackwardParams = namedtuple("FaBackwardParams",
+                              ["dx", "q", "k", "v", "softmax_res", "drop_mask", "pse", "scale", "keep_prob"])
+RebuildSoftmaxParams = namedtuple("RebuildSoftmaxParams",
+                                  ["q", "k", "atten_mask", "pse", "scale", "softmax_max", "softmax_sum"])
+
+
 def softmax_forward(x):
     x_max = torch.max(x, dim=-1, keepdims=True)[0]
     x_sub = x.sub(x_max)
@@ -99,7 +108,15 @@ def calculate_qk(q, k, atten_mask, pse, scale):
     return qk


-def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_prob):
+def fusion_attention_forward(forward_params):
+    q = forward_params.q
+    k = forward_params.k
+    v = forward_params.v
+    drop_mask = forward_params.drop_mask
+    atten_mask = forward_params.atten_mask
+    pse = forward_params.pse
+    scale = forward_params.scale
+    keep_prob = forward_params.keep_prob
     qk = calculate_qk(q, k, atten_mask, pse, scale)
     softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
     if drop_mask is None or len(drop_mask.shape) == 0:
@@ -110,7 +127,16 @@ def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_pr
     return y, softmax_max, softmax_sum


-def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob):
+def fusion_attention_backward(backward_params):
+    dx = backward_params.dx
+    q = backward_params.q
+    k = backward_params.k
+    v = backward_params.v
+    softmax_res = backward_params.softmax_res
+    drop_mask = backward_params.drop_mask
+    pse = backward_params.pse
+    scale = backward_params.scale
+    keep_prob = backward_params.keep_prob
     dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
     if drop_mask is None or len(drop_mask.shape) == 0:
         drop_res = softmax_res.permute(0, 1, 3, 2)
@@ -368,11 +394,18 @@ def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
     return softmax_res


-def rebuild_softmax_by_max_sum(q, k, atten_mask, pse, scale, softmax_max, softmax_sum):
+def rebuild_softmax_by_max_sum(softmax_params):
     """
     attention = softmax(QK^T/sqrt(d))V
     softmax(x_i) = e^(x_i - x_max_i) / x_sum_i)
     """
+    q = softmax_params.q
+    k = softmax_params.k
+    atten_mask = softmax_params.atten_mask
+    pse = softmax_params.pse
+    scale = softmax_params.scale
+    softmax_max = softmax_params.softmax_max
+    softmax_sum = softmax_params.softmax_sum
     logger.info("Using softmax_max and softmax_sum to rebuild original softmax")
     qk = calculate_qk(q, k, atten_mask, pse, scale)
     if softmax_max.shape[-1] == 0:
@@ -502,10 +535,8 @@ def npu_fusion_attention(*args, **kwargs):
     key = convert_to_bnsd(key, n2, input_layout)
     value = convert_to_bnsd(value, n2, input_layout)
     k_new, v_new = generate_kv(key, value, n1, n2)
-    out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new,
-                                                                    drop_mask=None, atten_mask=atten_mask,
-                                                                    pse=pse, scale=scale,
-                                                                    keep_prob=keep_prob)
+    forward_params = FaForwardParams(query, k_new, v_new, None, atten_mask, pse, scale, keep_prob)
+    out_golden, softmax_max, softmax_sum = fusion_attention_forward(forward_params)
     if out_golden.dim() == 5:
         out_golden = out_golden.reshape(out_golden.size(0), out_golden.size(1) * out_golden.size(2), out_golden.size(3),
                                         out_golden.size(4))
@@ -546,9 +577,10 @@ def npu_fusion_attention_grad(*args, **kwargs):
     if SOFTMAX_BUILD_MODE == "QKV":
         softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
     else:
-        softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
-
-    dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)
+        softmax_params = RebuildSoftmaxParams(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
+        softmax_res = rebuild_softmax_by_max_sum(softmax_params)
+    backward_params = FaBackwardParams(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)
+    dq, dk, dv = fusion_attention_backward(backward_params)

     # adapted for unequal N (by cdy)
     if not (n1 == n2):
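All of the hunks above are one refactor: the long positional signatures of the golden forward/backward helpers are folded into namedtuples, so call sites build a single params object instead of threading eight or nine arguments. A sketch of the new calling convention (shapes are illustrative; it assumes calculate_qk tolerates None for atten_mask and pse, which the existing drop_mask=None call site suggests but this diff does not show):

import torch

q = torch.randn(1, 2, 4, 8, dtype=torch.float64)  # hypothetical B, N, S, D
k = torch.randn(1, 2, 4, 8, dtype=torch.float64)
v = torch.randn(1, 2, 4, 8, dtype=torch.float64)

params = FaForwardParams(q=q, k=k, v=v, drop_mask=None, atten_mask=None,
                         pse=None, scale=8 ** -0.5, keep_prob=1.0)
out, softmax_max, softmax_sum = fusion_attention_forward(params)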
msprobe/pytorch/bench_functions/sort_v2.py (new file)
@@ -0,0 +1,21 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+def npu_sort_v2(x, dim=-1, descending=False, out=None):
+    y, _ = torch.sort(x, dim=dim, descending=descending)
+    return y
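A note on this golden function: unlike torch.sort, npu_sort_v2 returns only the sorted values, matching the single-output Ascend operator, and the out parameter is accepted but unused. A one-line check with the function above in scope:

import torch

x = torch.tensor([3.0, 1.0, 2.0])
print(npu_sort_v2(x))  # tensor([1., 2., 3.]), values only, no index tensor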
msprobe/pytorch/common/parse_json.py
@@ -24,7 +24,8 @@ def parse_json_info_forward_backward(json_path):
     real_data_path = dump_json.get("dump_data_dir")
     dump_data = dump_json.get("data")
     if dump_data is None:
-        raise ParseJsonException(ParseJsonException.InvalidDumpJson, "something wrong with dump, no data found in dump.json")
+        raise ParseJsonException(ParseJsonException.InvalidDumpJson,
+                                 "something wrong with dump, no data found in dump.json")
     if not dump_data:
         logger.warning("data field is empty, no overflow data found.")

msprobe/pytorch/common/utils.py
@@ -18,6 +18,7 @@ import os
 import pickle
 import random
 import stat
+import inspect
 from functools import wraps

 import numpy as np
@@ -105,8 +106,49 @@ def get_rank_if_initialized():
     raise DistributedNotInitializedError("torch distributed environment is not initialized")


-def seed_all(seed=1234, mode=False):
-    check_seed_all(seed, mode)
+def remove_dropout():
+    if torch.__version__ > "1.8":
+        logger.info_on_rank_0("For precision comparison, the probability p in the dropout method is set to 0.")
+        import torch.nn.functional as F
+        from torch import _VF
+        from torch.overrides import has_torch_function_unary, handle_torch_function
+
+        def function_dropout(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
+                             inplace: bool = False) -> torch.Tensor:
+            if has_torch_function_unary(input_tensor):
+                return handle_torch_function(
+                    function_dropout, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
+            if p < 0.0 or p > 1.0:
+                raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
+            return _VF.dropout_(input_tensor, 0., training) if inplace else _VF.dropout(input_tensor, 0., training)
+
+        def function_dropout2d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
+                               inplace: bool = False) -> torch.Tensor:
+            if has_torch_function_unary(input_tensor):
+                return handle_torch_function(
+                    function_dropout2d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
+            if p < 0.0 or p > 1.0:
+                raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
+            return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
+                                                                                                        0., training)
+
+        def function_dropout3d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
+                               inplace: bool = False) -> torch.Tensor:
+            if has_torch_function_unary(input_tensor):
+                return handle_torch_function(
+                    function_dropout3d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
+            if p < 0.0 or p > 1.0:
+                raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
+            return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
+                                                                                                        0., training)
+
+        F.dropout = function_dropout
+        F.dropout2d = function_dropout2d
+        F.dropout3d = function_dropout3d
+
+
+def seed_all(seed=1234, mode=False, rm_dropout=True):
+    check_seed_all(seed, mode, rm_dropout)
     try:
         random.seed(seed)
         os.environ['PYTHONHASHSEED'] = str(seed)
@@ -126,6 +168,8 @@ def seed_all(seed=1234, mode=False):
         else:
             torch_npu.npu.manual_seed_all(seed)
             torch_npu.npu.manual_seed(seed)
+        if rm_dropout:
+            remove_dropout()
     except Exception as e:
         logger.error(f"There is an unexpected error while determinating randomness. {e}")

@@ -359,3 +403,73 @@ def load_api_data(api_data_bytes):
     except Exception as e:
         raise RuntimeError(f"load api_data from bytes failed") from e
     return buffer
+
+
+def is_recomputation():
+    """Check if the current operation is in the re-computation phase.
+
+    This function inspects the current call stack to indicate whether the current operation is in the
+    re-computation phase. We use a blacklist mechanism, now supported megatron and mindspeed framework.
+    megatron: The 'backward' function is called by the 'torch/autograd/function.py' file.
+    mindspeed: The 'checkpoint_function_backward' function is called by the 'torch/autograd/function.py'
+    file or the custom module(use CheckpointWithoutOutput) with the 'recompute_fn' function is executed within the
+    'torch/utils/checkpoint.py' file.
+
+    Returns:
+        bool: True if in the re-computation phase, False otherwise.
+    """
+    backward_function_indices = []
+    call_stack = inspect.stack()
+
+    # Identify the function 'backward' is being executed within the 'torch/_tensor.py' file.
+    for frame_info in call_stack:
+        if frame_info.function == "recompute_fn" and frame_info.filename.endswith('torch/utils/checkpoint.py'):
+            del call_stack
+            return True
+
+    # Identify indices in the call stack where the specific function is being executed
+    for idx, frame_info in enumerate(call_stack):
+        if frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward':
+            backward_function_indices.append(idx)
+
+    # Check if the execution is within 'torch/autograd/function.py' file
+    for idx in backward_function_indices:
+        # The Megatron and MindSpeed L0&L1 scenes
+        if idx + 1 < len(call_stack) and call_stack[idx + 1].filename.endswith('torch/autograd/function.py'):
+            del call_stack
+            return True
+        # The latest MindSpeed L2 and ModelLink scenes
+        if idx + 2 < len(call_stack) and call_stack[idx + 2].filename.endswith('torch/autograd/function.py'):
+            del call_stack
+            return True
+
+    del call_stack
+    return False
+
+
+def check_save_param(variable, name, save_backward):
+    # try catch this api to skip invalid call
+    if not isinstance(variable, (list, dict, torch.Tensor, int, float, str)):
+        logger.warning("PrecisionDebugger.save variable type not valid, "
+                       "should be one of list, dict, torch.Tensor, int, float or string. "
+                       "Skip current save process.")
+        raise ValueError
+    if not isinstance(name, str):
+        logger.warning("PrecisionDebugger.save name not valid, "
+                       "should be string. "
+                       "skip current save process.")
+        raise ValueError
+    if not isinstance(save_backward, bool):
+        logger.warning("PrecisionDebugger.save_backward name not valid, "
+                       "should be bool. "
+                       "Skip current save process.")
+        raise ValueError
+
+
+def replace_last_occurrence(text, old, new):
+    if text is None:
+        return text
+    index = text.rfind(old)
+    if index != -1:
+        return text[:index] + text[index:].replace(old, new, 1)
+    return text
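Two things to note in this file: seed_all gains an rm_dropout switch that routes through the new remove_dropout, which monkey-patches torch.nn.functional.dropout, dropout2d and dropout3d to run with p=0 for deterministic precision comparisons; and is_recomputation walks the Python call stack to detect Megatron/MindSpeed recomputation frames. A minimal sketch of the dropout patch in action (assuming the package's documented from msprobe.pytorch import seed_all re-export):

import torch
import torch.nn.functional as F
from msprobe.pytorch import seed_all

seed_all(seed=1234, mode=False, rm_dropout=True)  # fix seeds and zero out dropout

x = torch.ones(8, 8)
y = F.dropout(x, p=0.9, training=True)
assert torch.equal(x, y)  # the patched dropout ignores p and keeps every element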
msprobe/pytorch/compare/distributed_compare.py
@@ -14,52 +14,40 @@
 # limitations under the License.

 import os
-from msprobe.core.common.utils import CompareException, check_compare_param, \
-    check_configuration_param, set_dump_path, get_dump_mode
-from msprobe.core.common.file_utils import create_directory
+
 from msprobe.core.common.exceptions import FileCheckException
+from msprobe.core.common.file_utils import create_directory
+from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
+    set_dump_path
+from msprobe.core.compare.acc_compare import ModeConfig
+from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json, set_stack_json_path
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.compare.pt_compare import PTComparator
-from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
+from msprobe.pytorch.compare.pt_compare import PTComparator, compare


 def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
-    if kwargs.get('suffix'):
+    if kwargs.get("suffix"):
         logger.error("Argument 'suffix' is not supported for compare_distributed.")
         raise CompareException(CompareException.INVALID_PARAM_ERROR)
-    stack_mode = kwargs.get('stack_mode', False)
-    auto_analyze = kwargs.get('auto_analyze', True)
-    fuzzy_match = kwargs.get('fuzzy_match', False)
-    is_print_compare_log = kwargs.get('is_print_compare_log', True)
+    is_print_compare_log = kwargs.get("is_print_compare_log", True)
     # get the ranks and match by order
     npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
     bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
     if len(npu_ranks) != len(bench_ranks):
-        logger.error('The number of ranks in the two runs are different. '
-                     'Unable to match the ranks. Please use another folder to compare '
-                     'or use compare() api and manually match the ranks.')
+        logger.error(
+            "The number of ranks in the two runs are different. "
+            "Unable to match the ranks. "
+            "Please use another folder to compare or use compare() api and manually match the ranks.")
         raise CompareException(CompareException.INVALID_PATH_ERROR)
     for nr, br in zip(npu_ranks, bench_ranks):
         npu_data_dir = os.path.join(npu_dump_dir, nr)
         bench_data_dir = os.path.join(bench_dump_dir, br)
         npu_path = extract_json(npu_data_dir, stack_json=False)
         bench_path = extract_json(bench_data_dir, stack_json=False)
-        stack_path = extract_json(npu_data_dir, stack_json=True)

         dump_result_param = {
-            'npu_json_path': npu_path,
-            'bench_json_path': bench_path,
-            'stack_json_path': stack_path,
-            'is_print_compare_log': is_print_compare_log
+            "npu_json_path": npu_path,
+            "bench_json_path": bench_path,
+            "is_print_compare_log": is_print_compare_log
         }
-        try:
-            set_dump_path(dump_result_param)
-            dump_mode = get_dump_mode(dump_result_param)
-            check_configuration_param(stack_mode, auto_analyze, fuzzy_match, is_print_compare_log)
-            create_directory(output_path)
-            check_compare_param(dump_result_param, output_path, dump_mode)
-        except (CompareException, FileCheckException) as error:
-            logger.error('Compare failed. Please check the arguments and do it again!')
-            raise CompareException(error.code) from error
-        pt_comparator = PTComparator()
-        pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', dump_mode=dump_mode, **kwargs)
+        compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs)
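The rewrite delegates all per-rank validation and stack handling to the new compare() entry point, so compare_distributed now only pairs rank directories and loops. Typical usage stays the same (a sketch; paths are placeholders):

from msprobe.pytorch import compare_distributed

# ranks are matched by sorted order; one result file per rank pair is
# written under ./output with a _rankX-rankY suffix
compare_distributed("./npu_dump/step0", "./bench_dump/step0", "./output",
                    is_print_compare_log=True)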
msprobe/pytorch/compare/pt_compare.py
@@ -14,19 +14,29 @@
 # limitations under the License.

 import os.path
+
 import torch
+
 from msprobe.core.common.const import FileCheckConst
-from msprobe.pytorch.common.log import logger
 from msprobe.core.common.exceptions import FileCheckException
-from msprobe.core.compare.acc_compare import Comparator
-from msprobe.core.common.utils import check_configuration_param, check_compare_param, \
-    CompareException, set_dump_path, get_dump_mode
 from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml
+from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
+    set_dump_path
+from msprobe.core.compare.acc_compare import Comparator, ModeConfig
+from msprobe.core.compare.utils import set_stack_json_path
+from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import load_pt


-class PTComparator (Comparator):
-    def __init__(self, data_mapping=None):
+class PTComparator(Comparator):
+    def __init__(self, mode_config, data_mapping=None):
+        super().__init__(mode_config)
+
+        self.stack_mode = mode_config.stack_mode
+        self.auto_analyze = mode_config.auto_analyze
+        self.fuzzy_match = mode_config.fuzzy_match
+        self.dump_mode = mode_config.dump_mode
+
         self.frame_name = PTComparator.__name__
         self.data_mapping = data_mapping
         if isinstance(self.data_mapping, str) or self.data_mapping is None:
@@ -37,23 +47,24 @@ class PTComparator (Comparator):
             raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
                             f"{type(self.data_mapping)}")

-    def load_mapping_file(self, mapping_file):
+    @staticmethod
+    def load_mapping_file(mapping_file):
         if isinstance(mapping_file, str):
             mapping_dict = load_yaml(mapping_file)
         else:
             mapping_dict = {}
         return mapping_dict
-
+
     def read_npy_data(self, dir_path, file_name):
         if not file_name:
             return None
         data_path = os.path.join(dir_path, file_name)
         path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
-                                   FileCheckConst.PT_SUFFIX, False)
+                                   FileCheckConst.PT_SUFFIX, False)
         data_path = path_checker.common_check()
         try:
-            data_value = load_pt(data_path,
-                                 to_cpu=True).detach()  # detach because numpy can not process gradient information
+            # detach because numpy can not process gradient information
+            data_value = load_pt(data_path, to_cpu=True).detach()
         except RuntimeError as e:
             # here we catch the exception raised by load_pt
             logger.error(f"Failed to load the .pt file at {data_path}.")
@@ -65,20 +76,29 @@ class PTComparator (Comparator):
         if data_value.dtype == torch.bfloat16:
             data_value = data_value.to(torch.float32)
         data_value = data_value.numpy()
-        return data_value
-
-
-def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False, **kwargs):
+        return data_value
+
+
+def compare(input_param, output_path, **kwargs):
     try:
+        auto_analyze = kwargs.get('auto_analyze', True)
+        fuzzy_match = kwargs.get('fuzzy_match', False)
+        data_mapping = kwargs.get('data_mapping', None)
+        suffix = kwargs.get('suffix', '')
+
         set_dump_path(input_param)
         dump_mode = get_dump_mode(input_param)
+        if "stack_json_path" in input_param:
+            stack_mode = kwargs.get('stack_mode', False)
+        else:
+            stack_mode = set_stack_json_path(input_param)  # set stack_mode and set "stack_json_path" in input_param
         check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
         create_directory(output_path)
-        check_compare_param(input_param, output_path, dump_mode)
-        data_mapping = kwargs.get('data_mapping', None)
+        check_compare_param(input_param, output_path, dump_mode, stack_mode)
     except (CompareException, FileCheckException) as error:
         logger.error('Compare failed. Please check the arguments and do it again!')
         raise CompareException(error.code) from error
-    pt_comparator = PTComparator(data_mapping)
-    pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
-                               auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, dump_mode=dump_mode)
+
+    mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode)
+    pt_comparator = PTComparator(mode_config, data_mapping)
+    pt_comparator.compare_core(input_param, output_path, suffix=suffix)
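With this change stack_json_path becomes optional: when it is missing, set_stack_json_path fills it in from the dump directory and enables stack_mode automatically, and stack_mode/auto_analyze/fuzzy_match move into **kwargs. A hedged usage sketch of the new entry point (paths are placeholders):

from msprobe.pytorch import compare

input_param = {
    "npu_json_path": "./npu_dump/dump.json",
    "bench_json_path": "./bench_dump/dump.json",
    "is_print_compare_log": True,
    # "stack_json_path" may be omitted; it is now discovered automatically
}
compare(input_param, output_path="./output", auto_analyze=True, fuzzy_match=False)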
msprobe/pytorch/debugger/debugger_config.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,7 +26,7 @@ class DebuggerConfig:
         self.task = task or common_config.task or Const.STATISTICS
         self.rank = common_config.rank if common_config.rank else []
         self.step = common_config.step if common_config.step else []
-        self.level = level or common_config.level or "L1"
+        self.level = level or common_config.level or Const.LEVEL_L1
         self.enable_dataloader = common_config.enable_dataloader
         self.scope = task_config.scope if task_config.scope else []
         self.list = task_config.list if task_config.list else []
@@ -34,10 +34,7 @@ class DebuggerConfig:
         self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
         self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
         self.framework = Const.PT_FRAMEWORK
-
-        if self.level == Const.LEVEL_L2:
-            self.is_backward_kernel_dump = False
-            self._check_and_adjust_config_with_l2()
+        self.async_dump = common_config.async_dump if common_config.async_dump else False

         if self.task == Const.FREE_BENCHMARK:
             self.fuzz_device = task_config.fuzz_device
@@ -64,6 +61,10 @@ class DebuggerConfig:

         self.check()

+        if self.level == Const.LEVEL_L2:
+            self.is_backward_kernel_dump = False
+            self._check_and_adjust_config_with_l2()
+
     def check_kwargs(self):
         if self.task and self.task not in Const.TASK_LIST:
             raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
@@ -74,29 +75,53 @@ class DebuggerConfig:
         if not self.dump_path:
             raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                    f"The dump_path not found.")
+        if not isinstance(self.async_dump, bool):
+            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+                                   f"The parameters async_dump should be bool.")
+        if self.async_dump and self.task == Const.TENSOR and not self.list:
+            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+                                   f"The parameters async_dump is true in tensor task, the parameters list cannot be "
+                                   f"empty.")
+        if self.task == Const.STRUCTURE and self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
+            logger.warning_on_rank_0(
+                f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
+                f"If not, the default level is {Const.LEVEL_MIX}."
+            )
+            self.level = Const.LEVEL_MIX

     def check(self):
         self.check_kwargs()
         return True

     def check_model(self, instance, start_model):
-        if self.level not in ["L0", "mix"]:
+        if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
             if instance.model is not None or start_model is not None:
-                logger.warning_on_rank_0(
+                logger.info_on_rank_0(
                     f"The current level is not L0 or mix level, so the model parameters will not be used.")
             return
-        if start_model is None:
-            if instance.model is None:
-                logger.error_on_rank_0(
-                    f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' argument.")
-                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
+        if start_model is None and instance.model is None:
+            logger.error_on_rank_0(
+                f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' parameter.")
+            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
+
+        instance.model = start_model if start_model is not None else instance.model
+        if isinstance(instance.model, torch.nn.Module):
             return
-        if isinstance(start_model, torch.nn.Module):
-            instance.model = start_model
+
+        error_model = None
+        if isinstance(instance.model, (list, tuple)):
+            for model in instance.model:
+                if not isinstance(model, torch.nn.Module):
+                    error_model = model
+                    break
         else:
-            logger.error_on_rank_0(f"The 'model' parameter of start must be a torch.nn.Module type.")
+            error_model = instance.model
+
+        if error_model is not None:
+            error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
+                          f"type, currently there is a {type(error_model)} type.")
             raise MsprobeException(
-                MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module")
+                MsprobeException.INVALID_PARAM_ERROR, error_info)

     def _check_and_adjust_config_with_l2(self):
         if self.scope:
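The reworked check_model accepts a list or tuple of modules at the L0/mix levels, not just a single torch.nn.Module, and reports the offending element's type otherwise. A hedged sketch (the config path and modules are placeholders):

import torch.nn as nn
from msprobe.pytorch import PrecisionDebugger

encoder, decoder = nn.Linear(8, 8), nn.Linear(8, 8)  # stand-in submodules

# L0/mix levels may now receive several modules at once
debugger = PrecisionDebugger(config_path="./config.json", model=[encoder, decoder])
debugger.start()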