mindstudio-probe 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/RECORD +85 -66
  3. msprobe/README.md +2 -2
  4. msprobe/core/common/const.py +34 -9
  5. msprobe/core/common/inplace_ops.yaml +1 -0
  6. msprobe/core/common/utils.py +14 -0
  7. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  8. msprobe/core/compare/merge_result/merge_result.py +8 -7
  9. msprobe/core/compare/merge_result/utils.py +81 -0
  10. msprobe/core/compare/utils.py +10 -0
  11. msprobe/core/data_dump/data_collector.py +58 -13
  12. msprobe/core/data_dump/data_processor/base.py +92 -8
  13. msprobe/core/data_dump/data_processor/factory.py +3 -0
  14. msprobe/core/data_dump/data_processor/mindspore_processor.py +17 -4
  15. msprobe/core/data_dump/data_processor/pytorch_processor.py +58 -7
  16. msprobe/core/data_dump/json_writer.py +26 -8
  17. msprobe/docs/01.installation.md +25 -0
  18. msprobe/docs/02.config_introduction.md +14 -12
  19. msprobe/docs/03.config_examples.md +24 -0
  20. msprobe/docs/05.data_dump_PyTorch.md +34 -15
  21. msprobe/docs/06.data_dump_MindSpore.md +45 -22
  22. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -2
  23. msprobe/docs/19.monitor.md +257 -260
  24. msprobe/docs/21.visualization_PyTorch.md +10 -0
  25. msprobe/docs/22.visualization_MindSpore.md +11 -0
  26. msprobe/docs/27.dump_json_instruction.md +24 -20
  27. msprobe/docs/28.debugger_save_instruction.md +94 -0
  28. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  29. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  30. msprobe/mindspore/__init__.py +1 -0
  31. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +26 -6
  32. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  33. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  34. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  35. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  36. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  37. msprobe/mindspore/common/utils.py +20 -2
  38. msprobe/mindspore/debugger/debugger_config.py +25 -2
  39. msprobe/mindspore/debugger/precision_debugger.py +25 -6
  40. msprobe/mindspore/dump/hook_cell/api_registry.py +2 -0
  41. msprobe/mindspore/dump/jit_dump.py +7 -6
  42. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  43. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  44. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  45. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  46. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  47. msprobe/mindspore/monitor/features.py +63 -0
  48. msprobe/mindspore/monitor/module_hook.py +821 -0
  49. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  50. msprobe/mindspore/monitor/utils.py +267 -0
  51. msprobe/mindspore/ms_config.py +8 -2
  52. msprobe/mindspore/service.py +95 -21
  53. msprobe/pytorch/__init__.py +0 -1
  54. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  55. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  56. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  57. msprobe/pytorch/bench_functions/mish.py +21 -0
  58. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  59. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  60. msprobe/pytorch/common/utils.py +71 -0
  61. msprobe/pytorch/debugger/debugger_config.py +19 -9
  62. msprobe/pytorch/debugger/precision_debugger.py +14 -0
  63. msprobe/pytorch/dump/module_dump/module_processer.py +10 -30
  64. msprobe/pytorch/function_factory.py +7 -1
  65. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  66. msprobe/pytorch/hook_module/wrap_distributed.py +4 -0
  67. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  68. msprobe/pytorch/monitor/csv2tb.py +10 -12
  69. msprobe/pytorch/monitor/module_hook.py +123 -104
  70. msprobe/pytorch/monitor/module_metric.py +6 -6
  71. msprobe/pytorch/monitor/optimizer_collect.py +45 -63
  72. msprobe/pytorch/monitor/utils.py +8 -43
  73. msprobe/pytorch/pt_config.py +19 -22
  74. msprobe/pytorch/service.py +103 -24
  75. msprobe/visualization/builder/graph_builder.py +31 -5
  76. msprobe/visualization/builder/msprobe_adapter.py +7 -5
  77. msprobe/visualization/graph/base_node.py +3 -2
  78. msprobe/visualization/graph/distributed_analyzer.py +80 -3
  79. msprobe/visualization/graph/node_op.py +4 -2
  80. msprobe/visualization/graph_service.py +3 -4
  81. msprobe/visualization/utils.py +10 -2
  82. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  83. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  84. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  85. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
msprobe/pytorch/bench_functions/apply_adam.py
@@ -0,0 +1,215 @@
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections import namedtuple
+ import torch
+
+
+ VarParams = namedtuple('VarParams', ['var', 'lr_t', 'm_t', 'beta1_broad', 'grad', 'epsilon', 'v_t'])
+
+
+ def _output_m_compute(m, beta1_broad, grad):
+ """
+ _output_m_compute
+ do compute m_t = m + (beta1 - 1) * (m - grad)
+ """
+ input_dtype = m.dtype
+
+ sneg_one = torch.ones((1), dtype=input_dtype) * -1
+ sneg_one = sneg_one.to(beta1_broad.device)
+
+ # `formula; beta1 -1`
+ vsub_beta1_1 = torch.add(beta1_broad, sneg_one)
+
+ # `formula; m - grad`
+ vsub_m_grad = torch.sub(m, grad)
+
+ # `formula; (beta1 - 1) * (m - grad)`
+ vmul_m = torch.mul(vsub_beta1_1, vsub_m_grad)
+
+ # `formula; m_t = m + (beta1 - 1) * (m - grad)`
+ m_t = torch.add(m, vmul_m)
+
+ return m_t
+
+
+ def _output_v_compute(v, beta2, grad):
+ """
+ _output_v_compute
+ do compute v_t = v + (1 - beta2)*(grad*grad -v)
+ """
+ input_dtype = v.dtype
+
+ sneg_one = torch.ones((1), dtype=input_dtype) * -1
+
+ # `formula; broadcast beta2 to vector`
+ beta2_tensor = torch.tensor(beta2, dtype=input_dtype)
+ beta2_broad = beta2_tensor.expand_as(v)
+
+ # `formula; beta2 - 1`
+ vsub_beta2_1 = torch.add(beta2_broad, sneg_one)
+ vsub_beta2_1 = vsub_beta2_1.to(v.device)
+
+ # `formula; grad * grad`
+ vmul_grad_grad = torch.mul(grad, grad)
+
+ # `formula; (v - grad*grad)`
+ vsub_v_grad = torch.sub(v, vmul_grad_grad)
+
+ # `formula; (beta2 -1) * (v - grad * grad)`
+ vmul_grad = torch.mul(vsub_beta2_1, vsub_v_grad)
+
+ # `formula; v_t = v + (beta2 - 1) * (v - grad * grad)`
+ v_t = torch.add(v, vmul_grad)
+
+ return v_t
+
+
+ def _inner_lr_compute(lr, beta2_power, beta1_power, compute_shape_tensor):
+ """
+ _inner_lr_compute
+ `formula; lr_t = learning_rate * (sqrt(1-beta2_power)) / (1 - beta1_power)`
+ """
+
+ input_dtype = compute_shape_tensor.dtype
+
+ s_one = torch.ones((1), dtype=input_dtype)
+
+ s_neg_one = torch.ones((1), dtype=input_dtype) * -1
+
+ # `formula; (1 - beta2_power)`
+ v_neg_beta2_power = torch.mul(beta2_power, s_neg_one)
+ v_add_beta2_power = torch.add(v_neg_beta2_power, s_one)
+
+ # `formula; sqrt(1 - beta2_power)`
+ v_sqrt_beta2_power = torch.sqrt(v_add_beta2_power)
+
+ # `formula; (1 - beta1_power)`
+ v_neg_beta1_power = torch.mul(beta1_power, s_neg_one)
+ v_add_beta1_power = torch.add(v_neg_beta1_power, s_one)
+
+ # `formula; learning_rate * (sqrt(1-beta2_power)`
+ res = torch.mul(lr, v_sqrt_beta2_power)
+
+ # `formula; learning_rate*(sqrt(1-beta2_power))/(1-beta1_power)`
+ res = torch.div(res, v_add_beta1_power)
+ return res.expand_as(compute_shape_tensor)
+
+
+ def _inner_eps_add_sqrt_vt_compute(epsilon, v_t):
+ """
+ (epsilon + sqrt(v_t) )
+ """
+ # `formula; sqrt(v_t)`
+ sqrt_vt = torch.sqrt(v_t)
+
+ # `formula; broadcast epsilon to vector`
+ input_dtype = v_t.dtype
+ epsilon_tensor = torch.tensor(epsilon, dtype=input_dtype)
+ epsilon_broad = epsilon_tensor.expand_as(v_t)
+ epsilon_broad = epsilon_broad.to(sqrt_vt.device)
+
+ # `formula; epsilon + sqrt(v_t)`
+ v_add_sqrt_v = torch.add(sqrt_vt, epsilon_broad)
+
+ return v_add_sqrt_v
+
+
+ def _output_var_t_compute_use_nesterov(varparams):
+ """
+ _output_var_t_compute_use_nesterov
+ `formula; var_t = var - lr_t * (m_t * beta1 + (1 - beta1) * grad) / (epsilon + sqrt(v_t))`
+ `formula; var_t = var - lr_t * (m_t * beta1 + (1 - beta1) * grad) / (epsilon + sqrt(v_t))`
+ """
+ var = varparams.var
+ lr_t = varparams.lr_t
+ m_t = varparams.m_t
+ beta1_broad = varparams.beta1_broad
+ grad = varparams.grad
+ epsilon = varparams.epsilon
+ v_t = varparams.v_t
+
+ input_dtype = var.dtype
+
+ s_one = torch.ones((1), dtype=input_dtype)
+
+ s_neg_one = torch.ones((1), dtype=input_dtype) * -1
+
+ # `formula; m_t * beta1`
+ v_muls_mt_beta1 = torch.mul(m_t, beta1_broad)
+
+ # `formula; 1 -beta1`
+ v_neg_beta1 = torch.mul(beta1_broad, s_neg_one)
+ vsub_1_beta1 = torch.add(v_neg_beta1, s_one)
+
+ # `formula; (1-beta1)* grad`
+ v_mul_grad = torch.mul(vsub_1_beta1, grad)
+
+ # `formula; (m_t*beta1 + (1 - beta1)*grad)`
+ v_div_left = torch.add(v_muls_mt_beta1, v_mul_grad)
+
+ # `formula; lr_t * (m_t*beta1 + (1 - beta1) * grad)`
+ # broadcast lr_t to vector
+
+ lrt_broad = lr_t.expand_as(var)
+ v_mul_left = torch.mul(lrt_broad, v_div_left)
+
+ # `formula; (epsilon + sqrt(v_t))`
+ v_add_sqrt_v = _inner_eps_add_sqrt_vt_compute(epsilon, v_t)
+
+ # `formula; lr_t * (m_t*beta1 + (1-beta1)*grad / (epsilon + sqrt(v_t))`
+ v_div_res = torch.div(v_mul_left, v_add_sqrt_v)
+
+ # `formula; var - lr_t * (m_t*beta1 + (1-beta1)*grad) / (epsilon + sqrt(v_t))`
+ v_t = torch.sub(var, v_div_res)
+
+ return v_t
+
+
+ def _output_var_t_compute(var, lr_t, m_t, epsilon, v_t):
+ """
+ _output_var_t_compute
+ `var_t = var - lr_t * m_t / (epsilon + sqrt(v_t))`
+ """
+ # `formula; lr_t * m_t`
+ lr_t = lr_t.to(m_t.device)
+ v_mul_left = torch.mul(lr_t, m_t)
+
+ # `formula; (epsilon + sqrt(v_t))`
+ v_add_sqrt_v = _inner_eps_add_sqrt_vt_compute(epsilon, v_t)
+
+ # `formula; lr_t * m_t /(epsilon + sqrt(v_t))`
+ v_div_res = torch.div(v_mul_left, v_add_sqrt_v)
+
+ # `formula; var - lr_t * m_t / (epsilon + sqrt(v_t))`
+ v_t = torch.sub(var, v_div_res)
+
+ return v_t
+
+
+ def npu_apply_adam(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, use_locking, use_nesterov, out):
+ var, m, v = out
+ input_dtype = m.dtype
+ beta1_tensor = torch.tensor(beta1, dtype=input_dtype).to(m.device)
+ beta1_broad = beta1_tensor.expand_as(m)
+ m_t = _output_m_compute(m, beta1_broad, grad)
+ v_t = _output_v_compute(v, beta2, grad)
+ lr_t = _inner_lr_compute(lr, beta2_power, beta1_power, grad)
+ if use_nesterov:
+ var_params = VarParams(var, lr_t, m_t, beta1_broad, grad, epsilon, v_t)
+ var_t = _output_var_t_compute_use_nesterov(var_params)
+ else:
+ var_t = _output_var_t_compute(var, lr_t, m_t, epsilon, v_t)
+ return var_t, m_t, v_t
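
Note: the file above appears to be the CPU reference ("bench") implementation registered for the NPU operator npu_apply_adam (see the function_factory.py hunk further down). A minimal sketch of exercising it on CPU; the shapes and hyperparameter values are illustrative assumptions, not taken from the package.

    import torch
    from msprobe.pytorch.bench_functions.apply_adam import npu_apply_adam

    # Toy optimizer state for a single 4-element parameter; values are illustrative.
    var = torch.zeros(4)          # parameter to update
    m = torch.zeros(4)            # first-moment estimate
    v = torch.zeros(4)            # second-moment estimate
    grad = torch.full((4,), 0.1)  # gradient for this step

    beta1, beta2, eps, step = 0.9, 0.999, 1e-8, 1
    lr = torch.tensor(1e-3)
    beta1_power = torch.tensor(beta1 ** step)
    beta2_power = torch.tensor(beta2 ** step)

    # use_locking / use_nesterov mirror the flags of the NPU operator.
    var_t, m_t, v_t = npu_apply_adam(beta1_power, beta2_power, lr, beta1, beta2,
                                     eps, grad, use_locking=False,
                                     use_nesterov=False, out=(var, m, v))
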
msprobe/pytorch/bench_functions/group_norm_silu.py
@@ -0,0 +1,27 @@
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+
+ def npu_group_norm_silu(x, gama, beta, group, eps):
+ if len(x.shape) != 4:
+ raise ValueError("x shape should be (N, C, H, W)")
+ res = torch.ops.aten.native_group_norm(x, gama, beta, x.shape[0], x.shape[1], x.shape[2] * x.shape[3], group, eps)
+ res = list(res)
+ if not res:
+ raise ValueError("run native_group_norm failed")
+ res[0] = torch.nn.functional.silu(res[0])
+ return res
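
Note: a quick way to sanity-check npu_group_norm_silu on CPU; the shapes below are illustrative and only require the channel count to be divisible by the group count.

    import torch
    from msprobe.pytorch.bench_functions.group_norm_silu import npu_group_norm_silu

    x = torch.randn(2, 8, 4, 4)  # (N, C, H, W)
    gama = torch.ones(8)         # per-channel scale, shape (C,)
    beta = torch.zeros(8)        # per-channel shift, shape (C,)
    out, mean, rstd = npu_group_norm_silu(x, gama, beta, group=4, eps=1e-5)
    # out is silu(group_norm(x)); mean and rstd are the per-group statistics.
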
msprobe/pytorch/bench_functions/mish.py
@@ -0,0 +1,21 @@
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+
+ def npu_mish(x):
+ mish = torch.nn.Mish()
+ return mish(x)
msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py
@@ -0,0 +1,44 @@
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import numpy as np
+
+
+ def softmax_func(x, axis=None):
+ x = x.float()
+ x_max = x.max(dim=axis, keepdims=True).values
+ x_sub = x - x_max
+ y = torch.exp(x_sub)
+ x_sum = y.sum(dim=axis, keepdims=True)
+ ans = 0 if (x_sum == 0).any() else y / x_sum
+ return ans
+
+
+ def npu_moe_gating_top_k_softmax(x, finished_optional, k):
+ input_dtype = x.dtype
+ num_expert = x.shape[-1]
+ softmax = softmax_func(x, -1)
+ softmax = softmax.to(input_dtype)
+ expert_idx = torch.argsort(-softmax, dim=-1, stable=True)
+ expert_idx = expert_idx[:, :k]
+ y = torch.gather(softmax, index=expert_idx, dim=-1)
+ if finished_optional is not None:
+ finished_optional = finished_optional.view(finished_optional.shape[0], 1)
+ finished_optional = finished_optional.expand(-1, k)
+ expert_idx = torch.where(finished_optional, num_expert, expert_idx)
+ row_idx = torch.arange(y.shape[0] * y.shape[1]).reshape(y.shape[1], y.shape[0]).t()
+
+ return y, expert_idx, row_idx
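
Note: a small illustration of the routing helper above; the token and expert counts are made up.

    import torch
    from msprobe.pytorch.bench_functions.moe_gating_top_k_softmax import npu_moe_gating_top_k_softmax

    x = torch.randn(6, 8)                        # gating logits: 6 tokens, 8 experts
    finished = torch.zeros(6, dtype=torch.bool)  # no token has finished decoding
    y, expert_idx, row_idx = npu_moe_gating_top_k_softmax(x, finished, k=2)
    # y:          top-2 softmax weights per token
    # expert_idx: indices of the selected experts (set to num_expert for finished tokens)
    # row_idx:    row indices laid out column-major (the torch.arange(...).reshape(...).t() above)
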
msprobe/pytorch/bench_functions/sort_v2.py
@@ -0,0 +1,21 @@
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+
+ def npu_sort_v2(x, dim=-1, descending=False, out=None):
+ y, _ = torch.sort(x, dim=dim, descending=descending)
+ return y
msprobe/pytorch/common/utils.py
@@ -18,6 +18,7 @@ import os
  import pickle
  import random
  import stat
+ import inspect
  from functools import wraps

  import numpy as np
@@ -402,3 +403,73 @@ def load_api_data(api_data_bytes):
  except Exception as e:
  raise RuntimeError(f"load api_data from bytes failed") from e
  return buffer
+
+
+ def is_recomputation():
+ """Check if the current operation is in the re-computation phase.
+
+ This function inspects the current call stack to indicate whether the current operation is in the
+ re-computation phase. We use a blacklist mechanism, now supported megatron and mindspeed framework.
+ megatron: The 'backward' function is called by the 'torch/autograd/function.py' file.
+ mindspeed: The 'checkpoint_function_backward' function is called by the 'torch/autograd/function.py'
+ file or the custom module(use CheckpointWithoutOutput) with the 'recompute_fn' function is executed within the
+ 'torch/utils/checkpoint.py' file.
+
+ Returns:
+ bool: True if in the re-computation phase, False otherwise.
+ """
+ backward_function_indices = []
+ call_stack = inspect.stack()
+
+ # Identify the function 'backward' is being executed within the 'torch/_tensor.py' file.
+ for frame_info in call_stack:
+ if frame_info.function == "recompute_fn" and frame_info.filename.endswith('torch/utils/checkpoint.py'):
+ del call_stack
+ return True
+
+ # Identify indices in the call stack where the specific function is being executed
+ for idx, frame_info in enumerate(call_stack):
+ if frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward':
+ backward_function_indices.append(idx)
+
+ # Check if the execution is within 'torch/autograd/function.py' file
+ for idx in backward_function_indices:
+ # The Megatron and MindSpeed L0&L1 scenes
+ if idx + 1 < len(call_stack) and call_stack[idx + 1].filename.endswith('torch/autograd/function.py'):
+ del call_stack
+ return True
+ # The latest MindSpeed L2 and ModelLink scenes
+ if idx + 2 < len(call_stack) and call_stack[idx + 2].filename.endswith('torch/autograd/function.py'):
+ del call_stack
+ return True
+
+ del call_stack
+ return False
+
+
+ def check_save_param(variable, name, save_backward):
+ # try catch this api to skip invalid call
+ if not isinstance(variable, (list, dict, torch.Tensor, int, float, str)):
+ logger.warning("PrecisionDebugger.save variable type not valid, "
+ "should be one of list, dict, torch.Tensor, int, float or string. "
+ "Skip current save process.")
+ raise ValueError
+ if not isinstance(name, str):
+ logger.warning("PrecisionDebugger.save name not valid, "
+ "should be string. "
+ "skip current save process.")
+ raise ValueError
+ if not isinstance(save_backward, bool):
+ logger.warning("PrecisionDebugger.save_backward name not valid, "
+ "should be bool. "
+ "Skip current save process.")
+ raise ValueError
+
+
+ def replace_last_occurrence(text, old, new):
+ if text is None:
+ return text
+ index = text.rfind(old)
+ if index != -1:
+ return text[:index] + text[index:].replace(old, new, 1)
+ return text
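
Note: replace_last_occurrence is the helper the module processor (see the module_processer.py hunks below) now uses to map a backward dump name back to its forward counterpart without disturbing earlier occurrences of the word in the scope prefix. A quick illustrative check; the name string is made up.

    from msprobe.pytorch.common.utils import replace_last_occurrence

    name = "Module.backward_block.TransformerLayer.backward.0"
    print(replace_last_occurrence(name, "backward", "forward"))
    # Module.backward_block.TransformerLayer.forward.0  (only the last "backward" is rewritten)
    print(replace_last_occurrence(None, "backward", "forward"))
    # None
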
msprobe/pytorch/debugger/debugger_config.py
@@ -1,4 +1,4 @@
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
  # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,7 +26,7 @@ class DebuggerConfig:
  self.task = task or common_config.task or Const.STATISTICS
  self.rank = common_config.rank if common_config.rank else []
  self.step = common_config.step if common_config.step else []
- self.level = level or common_config.level or "L1"
+ self.level = level or common_config.level or Const.LEVEL_L1
  self.enable_dataloader = common_config.enable_dataloader
  self.scope = task_config.scope if task_config.scope else []
  self.list = task_config.list if task_config.list else []
@@ -36,10 +36,6 @@ class DebuggerConfig:
  self.framework = Const.PT_FRAMEWORK
  self.async_dump = common_config.async_dump if common_config.async_dump else False

- if self.level == Const.LEVEL_L2:
- self.is_backward_kernel_dump = False
- self._check_and_adjust_config_with_l2()
-
  if self.task == Const.FREE_BENCHMARK:
  self.fuzz_device = task_config.fuzz_device
  self.handler_type = task_config.handler_type
@@ -65,6 +61,10 @@

  self.check()

+ if self.level == Const.LEVEL_L2:
+ self.is_backward_kernel_dump = False
+ self._check_and_adjust_config_with_l2()
+
  def check_kwargs(self):
  if self.task and self.task not in Const.TASK_LIST:
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
@@ -78,6 +78,16 @@ class DebuggerConfig:
  if not isinstance(self.async_dump, bool):
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
  f"The parameters async_dump should be bool.")
+ if self.async_dump and self.task == Const.TENSOR and not self.list:
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+ f"The parameters async_dump is true in tensor task, the parameters list cannot be "
+ f"empty.")
+ if self.task == Const.STRUCTURE and self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
+ logger.warning_on_rank_0(
+ f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
+ f"If not, the default level is {Const.LEVEL_MIX}."
+ )
+ self.level = Const.LEVEL_MIX

  def check(self):
  self.check_kwargs()
@@ -93,10 +103,10 @@
  logger.error_on_rank_0(
  f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' parameter.")
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
-
+
  instance.model = start_model if start_model is not None else instance.model
  if isinstance(instance.model, torch.nn.Module):
- return
+ return

  error_model = None
  if isinstance(instance.model, (list, tuple)):
@@ -108,7 +118,7 @@
  error_model = instance.model

  if error_model is not None:
- error_info = (f"The 'model' parameter must be a torch.nn.Moudle or list[torch.nn.Moudle] "
+ error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
  f"type, currently there is a {type(error_model)} type.")
  raise MsprobeException(
  MsprobeException.INVALID_PARAM_ERROR, error_info)
msprobe/pytorch/debugger/precision_debugger.py
@@ -21,6 +21,7 @@ from msprobe.core.common.exceptions import MsprobeException
  from msprobe.core.common.file_utils import FileChecker
  from msprobe.core.common.utils import get_real_step_or_rank
  from msprobe.pytorch.common.log import logger
+ from msprobe.pytorch.common.utils import check_save_param
  from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
  from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper
  from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
@@ -158,6 +159,19 @@ class PrecisionDebugger:
  return
  cls._instance.gm.monitor(model)

+ @classmethod
+ def save(cls, variable, name, save_backward=True):
+ instance = cls._instance
+ if not instance:
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+ if instance.task not in [Const.TENSOR, Const.STATISTICS] or instance.config.level != Const.LEVEL_DEBUG:
+ return
+ try:
+ check_save_param(variable, name, save_backward)
+ except ValueError:
+ return
+ instance.service.save(variable, name, save_backward)
+

  def module_dump(module, dump_name):
  if not isinstance(module, torch.nn.Module):
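
Note: the new PrecisionDebugger.save interface only takes effect when the configured task is statistics or tensor and the level is debug, and it silently skips invalid arguments via check_save_param. A rough usage sketch, assuming a config.json set up as described in msprobe/docs/28.debugger_save_instruction.md:

    import torch
    from msprobe.pytorch import PrecisionDebugger

    debugger = PrecisionDebugger(config_path="./config.json")  # task "statistics"/"tensor", level "debug"
    x = torch.randn(4, 4)
    PrecisionDebugger.save(x, "x_debug", save_backward=True)   # dump the variable under the name "x_debug"
    PrecisionDebugger.step()
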
msprobe/pytorch/dump/module_dump/module_processer.py
@@ -19,6 +19,7 @@ import torch
  from msprobe.core.common.const import Const
  from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
  from msprobe.pytorch.common.log import logger
+ from msprobe.pytorch.common.utils import replace_last_occurrence
  from torch.utils.checkpoint import checkpoint as origin_checkpoint
  from torch.utils.checkpoint import set_checkpoint_early_stop
  from torch.utils.hooks import BackwardHook
@@ -45,29 +46,8 @@ class ModuleProcesser:
  self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
  BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook)
  BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
- BackwardHook.setup_output_hook = ModuleProcesser.filter_tensor_and_tuple(BackwardHook.setup_output_hook)
  replace_checkpoint()

- @staticmethod
- def filter_tensor_and_tuple(func):
- @wraps(func)
- def wrap_by_filter_tensor_and_tuple(*args, **kwargs):
- # If setup_output_hook receives non-tensor data, the tool's subsequent dump fails; the workaround is to parse the attributes of the non-tensor data and hook its tensor attributes
- # setup_output_hook is defined as setup_output_hook(self, args), so the second positional argument, i.e. *args[1], is the one handled
- if not isinstance(args[1], (torch.Tensor, tuple)):
- for item_str in dir(args[1]):
- item = getattr(args[1], item_str)
- # Handle a tensor, or a tuple containing only tensors
- if isinstance(item, torch.Tensor) or \
- (isinstance(item, tuple) and all(isinstance(x, torch.Tensor) for x in item)):
- args_new = (args[0], item)
- result = func(*args_new, **kwargs)
- setattr(args[1], item_str, result)
- return args[1]
- return func(*args, **kwargs)
-
- return wrap_by_filter_tensor_and_tuple
-
  @staticmethod
  def clone_return_value(func):
  @wraps(func)
@@ -81,11 +61,11 @@ class ModuleProcesser:
  def clone_if_tensor(result):
  if isinstance(result, torch.Tensor):
  return result.clone()
- elif isinstance(result, tuple):
+ elif type(result) is tuple:
  return tuple(ModuleProcesser.clone_if_tensor(x) for x in result)
- elif isinstance(result, list):
+ elif type(result) is list:
  return list(ModuleProcesser.clone_if_tensor(x) for x in result)
- elif isinstance(result, dict):
+ elif type(result) is dict:
  return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()}
  else:
  return result
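
Note: the switch from isinstance checks to exact type checks presumably keeps tuple subclasses such as namedtuples (a common container for module outputs) from being rebuilt as plain tuples; only results whose type is exactly tuple, list or dict are cloned element-wise. A small illustration of the difference:

    from collections import namedtuple

    Point = namedtuple("Point", "x y")
    p = Point(1, 2)
    print(isinstance(p, tuple))  # True  -> the old branch would rebuild p as a plain tuple
    print(type(p) is tuple)      # False -> the new branch returns p unchanged
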
@@ -103,7 +83,7 @@ class ModuleProcesser:
  return hasattr(module, '_backward_hooks') and \
  len(module._backward_hooks) > 0 and \
  module._is_full_backward_hook is False
-
+
  @staticmethod
  def get_modules_and_names(models):
  modules_and_names_with_index = {}
@@ -130,8 +110,8 @@ class ModuleProcesser:
  if module == model:
  continue
  module_index = (index + Const.SEP) if index != "-1" else ""
- prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
- name + Const.SEP + module.__class__.__name__ + Const.SEP)
+ prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
+ name + Const.SEP + module.__class__.__name__ + Const.SEP)
  pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = build_hook(
  BaseScope.Module_Type_Module,
  prefix_name
@@ -203,9 +183,9 @@ class ModuleProcesser:
  if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
  module.mindstudio_reserved_name = []
  module.mindstudio_reserved_name.append(full_name)
- forward_full_name = full_name.replace(Const.BACKWARD, Const.FORWARD)
- ModuleProcesser.module_node[full_name] = ModuleProcesser.module_node[forward_full_name].replace(
- Const.FORWARD, Const.BACKWARD) if ModuleProcesser.module_node[forward_full_name] else None
+ forward_full_name = replace_last_occurrence(full_name, Const.BACKWARD, Const.FORWARD)
+ ModuleProcesser.module_node[full_name] = replace_last_occurrence(
+ ModuleProcesser.module_node.get(forward_full_name), Const.FORWARD, Const.BACKWARD)
  ModuleProcesser.api_parent_node = None
  if self.scope:
  self.scope.begin_module(full_name)
msprobe/pytorch/function_factory.py
@@ -27,6 +27,11 @@ from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotar
  from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \
  npu_scaled_masked_softmax_backward
  from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward
+ from msprobe.pytorch.bench_functions.apply_adam import npu_apply_adam
+ from msprobe.pytorch.bench_functions.group_norm_silu import npu_group_norm_silu
+ from msprobe.pytorch.bench_functions.mish import npu_mish
+ from msprobe.pytorch.bench_functions.moe_gating_top_k_softmax import npu_moe_gating_top_k_softmax
+ from msprobe.pytorch.bench_functions.sort_v2 import npu_sort_v2
  from msprobe.pytorch.common.utils import logger


@@ -79,7 +84,8 @@ class Register(dict):
  npu_custom_functions = Register()
  npu_custom_functions([
  npu_apply_adam_w, npu_confusion_transpose, npu_fast_gelu, npu_layer_norm_eval, npu_linear, npu_fusion_attention,
- npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu, gpu_fusion_attention
+ npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu, gpu_fusion_attention, npu_apply_adam,
+ npu_group_norm_silu, npu_mish, npu_moe_gating_top_k_softmax, npu_sort_v2
  ])

  # register for npu custom backward bench functions
msprobe/pytorch/hook_module/support_wrap_ops.yaml
@@ -1911,4 +1911,5 @@ distributed:
  - all_to_all_single
  - all_to_all
  - all_gather_into_tensor
- - reduce_scatter_tensor
+ - reduce_scatter_tensor
+ - batch_isend_irecv
msprobe/pytorch/hook_module/wrap_distributed.py
@@ -57,6 +57,10 @@ class DistributedOPTemplate(HOOKModule):
  if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]:
  if handle and hasattr(handle, 'wait'):
  handle.wait()
+ if self.op_name_ == "batch_isend_irecv":
+ if isinstance(handle, list):
+ for req in handle:
+ req.wait()
  return handle

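
Note: the new branch above exists because torch.distributed.batch_isend_irecv returns a list of request handles rather than a single handle with a wait() method, so the hook waits on each request before returning. A sketch of the call pattern it targets, assuming a process group has already been initialized with world_size == 2:

    import torch
    import torch.distributed as dist

    def pairwise_exchange(rank):
        # Assumes dist.init_process_group(...) was already called with world_size == 2.
        peer = 1 - rank
        send_buf = torch.ones(4) * rank
        recv_buf = torch.empty(4)
        ops = [dist.P2POp(dist.isend, send_buf, peer),
               dist.P2POp(dist.irecv, recv_buf, peer)]
        reqs = dist.batch_isend_irecv(ops)  # returns a list of requests
        for req in reqs:                    # the wrapped op waits on each, as above
            req.wait()
        return recv_buf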