mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
  3. msprobe/README.md +27 -22
  4. msprobe/core/common/const.py +129 -60
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/inplace_ops.yaml +1 -0
  9. msprobe/core/common/utils.py +43 -33
  10. msprobe/core/compare/acc_compare.py +43 -74
  11. msprobe/core/compare/check.py +2 -6
  12. msprobe/core/compare/highlight.py +2 -0
  13. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  14. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  15. msprobe/core/compare/merge_result/merge_result.py +16 -9
  16. msprobe/core/compare/merge_result/utils.py +81 -0
  17. msprobe/core/compare/multiprocessing_compute.py +19 -12
  18. msprobe/core/compare/npy_compare.py +30 -12
  19. msprobe/core/compare/utils.py +30 -10
  20. msprobe/core/data_dump/api_registry.py +176 -0
  21. msprobe/core/data_dump/data_collector.py +58 -13
  22. msprobe/core/data_dump/data_processor/base.py +94 -10
  23. msprobe/core/data_dump/data_processor/factory.py +3 -0
  24. msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
  25. msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
  26. msprobe/core/data_dump/json_writer.py +61 -40
  27. msprobe/core/grad_probe/constant.py +1 -0
  28. msprobe/core/grad_probe/grad_compare.py +1 -1
  29. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  30. msprobe/docs/01.installation.md +27 -1
  31. msprobe/docs/02.config_introduction.md +27 -23
  32. msprobe/docs/03.config_examples.md +24 -0
  33. msprobe/docs/05.data_dump_PyTorch.md +103 -16
  34. msprobe/docs/06.data_dump_MindSpore.md +76 -32
  35. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  36. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  37. msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
  38. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  39. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  40. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  41. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  42. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  43. msprobe/docs/18.online_dispatch.md +1 -1
  44. msprobe/docs/19.monitor.md +332 -273
  45. msprobe/docs/21.visualization_PyTorch.md +42 -13
  46. msprobe/docs/22.visualization_MindSpore.md +43 -13
  47. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  48. msprobe/docs/27.dump_json_instruction.md +301 -27
  49. msprobe/docs/28.debugger_save_instruction.md +94 -0
  50. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  51. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  52. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  53. msprobe/docs/FAQ.md +3 -11
  54. msprobe/docs/img/compare_result.png +0 -0
  55. msprobe/docs/img/merge_result.png +0 -0
  56. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  57. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  58. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  59. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  60. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  61. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  63. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  64. msprobe/mindspore/__init__.py +4 -2
  65. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
  66. msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
  67. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  68. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  69. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  70. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  71. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  72. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  73. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
  74. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  75. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  76. msprobe/mindspore/common/const.py +61 -0
  77. msprobe/mindspore/common/utils.py +48 -18
  78. msprobe/mindspore/compare/ms_compare.py +27 -19
  79. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  80. msprobe/mindspore/debugger/debugger_config.py +31 -6
  81. msprobe/mindspore/debugger/precision_debugger.py +45 -14
  82. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  83. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  86. msprobe/mindspore/dump/jit_dump.py +21 -15
  87. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  88. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  89. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  90. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  91. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  92. msprobe/mindspore/grad_probe/global_context.py +2 -0
  93. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  94. msprobe/mindspore/grad_probe/hook.py +2 -4
  95. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  96. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  97. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  98. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  99. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  100. msprobe/mindspore/monitor/features.py +63 -0
  101. msprobe/mindspore/monitor/module_hook.py +873 -0
  102. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  103. msprobe/mindspore/monitor/utils.py +309 -0
  104. msprobe/mindspore/ms_config.py +8 -2
  105. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  106. msprobe/mindspore/service.py +114 -34
  107. msprobe/pytorch/__init__.py +0 -1
  108. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  109. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
  110. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  111. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  112. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  116. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  117. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  118. msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
  119. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
  120. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  121. msprobe/pytorch/common/utils.py +97 -4
  122. msprobe/pytorch/debugger/debugger_config.py +19 -9
  123. msprobe/pytorch/debugger/precision_debugger.py +24 -1
  124. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  125. msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
  126. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  127. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  132. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  133. msprobe/pytorch/function_factory.py +8 -2
  134. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  135. msprobe/pytorch/hook_module/api_register.py +131 -0
  136. msprobe/pytorch/hook_module/hook_module.py +19 -14
  137. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  138. msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
  139. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  140. msprobe/pytorch/monitor/csv2tb.py +18 -14
  141. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  142. msprobe/pytorch/monitor/module_hook.py +238 -193
  143. msprobe/pytorch/monitor/module_metric.py +9 -6
  144. msprobe/pytorch/monitor/optimizer_collect.py +100 -67
  145. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  146. msprobe/pytorch/monitor/utils.py +76 -44
  147. msprobe/pytorch/online_dispatch/compare.py +0 -2
  148. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  149. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  150. msprobe/pytorch/online_dispatch/utils.py +3 -0
  151. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  152. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  153. msprobe/pytorch/pt_config.py +30 -29
  154. msprobe/pytorch/service.py +114 -32
  155. msprobe/visualization/builder/graph_builder.py +75 -10
  156. msprobe/visualization/builder/msprobe_adapter.py +7 -6
  157. msprobe/visualization/compare/graph_comparator.py +42 -38
  158. msprobe/visualization/compare/mode_adapter.py +0 -19
  159. msprobe/visualization/graph/base_node.py +11 -3
  160. msprobe/visualization/graph/distributed_analyzer.py +71 -3
  161. msprobe/visualization/graph/graph.py +0 -11
  162. msprobe/visualization/graph/node_op.py +4 -3
  163. msprobe/visualization/graph_service.py +4 -5
  164. msprobe/visualization/utils.py +12 -35
  165. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
  166. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  167. msprobe/pytorch/hook_module/api_registry.py +0 -166
  168. msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
  169. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  171. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  172. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  173. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  174. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  175. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  176. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  177. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
4
2
  # All rights reserved.
5
3
  #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
5
  # you may not use this file except in compliance with the License.
8
6
  # You may obtain a copy of the License at
9
7
  #
@@ -18,8 +16,8 @@
18
16
  import os
19
17
  from collections import namedtuple
20
18
  import re
21
- import torch
22
19
 
20
+ import torch
23
21
  try:
24
22
  import torch_npu
25
23
  except ImportError:
@@ -33,11 +31,9 @@ from msprobe.core.common.const import FileCheckConst, Const, CompareConst
33
31
  from msprobe.core.common.file_utils import FileChecker
34
32
  from msprobe.core.common.log import logger
35
33
  from msprobe.core.common.utils import CompareException
34
+ from msprobe.pytorch.hook_module.api_register import ApiTemplate, get_api_register
36
35
  from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
37
- from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
38
- from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
39
- from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
40
- from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
36
+
41
37
 
42
38
  hf_32_standard_api = ["conv1d", "conv2d"]
43
39
  not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
@@ -108,17 +104,30 @@ def exec_api(exec_params):
108
104
  kwargs = exec_params.kwargs
109
105
  is_autocast = exec_params.is_autocast
110
106
  autocast_dtype = exec_params.autocast_dtype
111
-
112
- if api_type == "Functional":
113
- torch_api = FunctionalOPTemplate(api_name, str, False)
114
- if api_type == "Tensor":
115
- torch_api = TensorOPTemplate(api_name, str, False)
116
- if api_type == "Torch":
117
- torch_api = TorchOPTemplate(api_name, str, False)
118
- if api_type == "Aten":
107
+ out = None
108
+
109
+ prefix_map = Const.API_DATA_PREFIX.get(Const.PT_FRAMEWORK, {})
110
+ if not prefix_map or api_type not in prefix_map.values() or \
111
+ api_type not in (
112
+ Const.FUNCTIONAL_API_TYPE_PREFIX,
113
+ Const.TENSOR_API_TYPE_PREFIX,
114
+ Const.TORCH_API_TYPE_PREFIX,
115
+ Const.ATEN_API_TYPE_PREFIX,
116
+ Const.NPU_API_TYPE_PREFIX
117
+ ):
118
+ return out
119
+
120
+ if api_type == Const.ATEN_API_TYPE_PREFIX:
119
121
  torch_api = AtenOPTemplate(api_name, None, False)
120
- if api_type == "NPU":
121
- torch_api = NpuOPTemplate(api_name, None, False, device)
122
+ else:
123
+ api_register = get_api_register()
124
+ api_register.initialize_hook(None)
125
+ api_func_type = list(prefix_map.keys())[list(prefix_map.values()).index(api_type)]
126
+ api_func = api_register.ori_api_attr.get(Const.PT_FRAMEWORK + Const.SEP + api_func_type, {}).get(api_name)
127
+ if api_func is None:
128
+ return out
129
+
130
+ torch_api = ApiTemplate(api_name, api_func, api_type, None, need_hook=False, device=device)
122
131
  if is_autocast:
123
132
  with autocast(dtype=autocast_dtype):
124
133
  out = torch_api.forward(*args, **kwargs)
@@ -27,6 +27,7 @@ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import T
27
27
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
28
28
  from msprobe.core.common.file_utils import remove_path
29
29
  from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl
30
+ from msprobe.core.common.decorator import recursion_depth_decorator
30
31
 
31
32
  BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]]
32
33
 
@@ -168,11 +169,12 @@ class ATTL:
168
169
  return buffer
169
170
 
170
171
 
172
+ @recursion_depth_decorator("move2device_exec")
171
173
  def move2device_exec(obj, device):
172
174
  if isinstance(obj, (tuple, list)):
173
175
  data_list = [move2device_exec(val, device) for val in obj]
174
176
  return data_list if isinstance(obj, list) else tuple(data_list)
175
- if isinstance(obj, dict):
177
+ if isinstance(obj, dict):
176
178
  return {key: move2device_exec(val, device) for key, val in obj.items()}
177
179
  elif isinstance(obj, torch.Tensor):
178
180
  obj = obj.detach()
@@ -0,0 +1,215 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import namedtuple
17
+ import torch
18
+
19
+
20
# Bundle of operands for the Nesterov variant of the var_t update
# (consumed by _output_var_t_compute_use_nesterov).
VarParams = namedtuple("VarParams", "var lr_t m_t beta1_broad grad epsilon v_t")
21
+
22
+
23
+ def _output_m_compute(m, beta1_broad, grad):
24
+ """
25
+ _output_m_compute
26
+ do compute m_t = m + (beta1 - 1) * (m - grad)
27
+ """
28
+ input_dtype = m.dtype
29
+
30
+ sneg_one = torch.ones((1), dtype=input_dtype) * -1
31
+ sneg_one = sneg_one.to(beta1_broad.device)
32
+
33
+ # `formula; beta1 -1`
34
+ vsub_beta1_1 = torch.add(beta1_broad, sneg_one)
35
+
36
+ # `formula; m - grad`
37
+ vsub_m_grad = torch.sub(m, grad)
38
+
39
+ # `formula; (beta1 - 1) * (m - grad)`
40
+ vmul_m = torch.mul(vsub_beta1_1, vsub_m_grad)
41
+
42
+ # `formula; m_t = m + (beta1 - 1) * (m - grad)`
43
+ m_t = torch.add(m, vmul_m)
44
+
45
+ return m_t
46
+
47
+
48
+ def _output_v_compute(v, beta2, grad):
49
+ """
50
+ _output_v_compute
51
+ do compute v_t = v + (1 - beta2)*(grad*grad -v)
52
+ """
53
+ input_dtype = v.dtype
54
+
55
+ sneg_one = torch.ones((1), dtype=input_dtype) * -1
56
+
57
+ # `formula; broadcast beta2 to vector`
58
+ beta2_tensor = torch.tensor(beta2, dtype=input_dtype)
59
+ beta2_broad = beta2_tensor.expand_as(v)
60
+
61
+ # `formula; beta2 - 1`
62
+ vsub_beta2_1 = torch.add(beta2_broad, sneg_one)
63
+ vsub_beta2_1 = vsub_beta2_1.to(v.device)
64
+
65
+ # `formula; grad * grad`
66
+ vmul_grad_grad = torch.mul(grad, grad)
67
+
68
+ # `formula; (v - grad*grad)`
69
+ vsub_v_grad = torch.sub(v, vmul_grad_grad)
70
+
71
+ # `formula; (beta2 -1) * (v - grad * grad)`
72
+ vmul_grad = torch.mul(vsub_beta2_1, vsub_v_grad)
73
+
74
+ # `formula; v_t = v + (beta2 - 1) * (v - grad * grad)`
75
+ v_t = torch.add(v, vmul_grad)
76
+
77
+ return v_t
78
+
79
+
80
+ def _inner_lr_compute(lr, beta2_power, beta1_power, compute_shape_tensor):
81
+ """
82
+ _inner_lr_compute
83
+ `formula; lr_t = learning_rate * (sqrt(1-beta2_power)) / (1 - beta1_power)`
84
+ """
85
+
86
+ input_dtype = compute_shape_tensor.dtype
87
+
88
+ s_one = torch.ones((1), dtype=input_dtype)
89
+
90
+ s_neg_one = torch.ones((1), dtype=input_dtype) * -1
91
+
92
+ # `formula; (1 - beta2_power)`
93
+ v_neg_beta2_power = torch.mul(beta2_power, s_neg_one)
94
+ v_add_beta2_power = torch.add(v_neg_beta2_power, s_one)
95
+
96
+ # `formula; sqrt(1 - beta2_power)`
97
+ v_sqrt_beta2_power = torch.sqrt(v_add_beta2_power)
98
+
99
+ # `formula; (1 - beta1_power)`
100
+ v_neg_beta1_power = torch.mul(beta1_power, s_neg_one)
101
+ v_add_beta1_power = torch.add(v_neg_beta1_power, s_one)
102
+
103
+ # `formula; learning_rate * (sqrt(1-beta2_power)`
104
+ res = torch.mul(lr, v_sqrt_beta2_power)
105
+
106
+ # `formula; learning_rate*(sqrt(1-beta2_power))/(1-beta1_power)`
107
+ res = torch.div(res, v_add_beta1_power)
108
+ return res.expand_as(compute_shape_tensor)
109
+
110
+
111
+ def _inner_eps_add_sqrt_vt_compute(epsilon, v_t):
112
+ """
113
+ (epsilon + sqrt(v_t) )
114
+ """
115
+ # `formula; sqrt(v_t)`
116
+ sqrt_vt = torch.sqrt(v_t)
117
+
118
+ # `formula; broadcast epsilon to vector`
119
+ input_dtype = v_t.dtype
120
+ epsilon_tensor = torch.tensor(epsilon, dtype=input_dtype)
121
+ epsilon_broad = epsilon_tensor.expand_as(v_t)
122
+ epsilon_broad = epsilon_broad.to(sqrt_vt.device)
123
+
124
+ # `formula; epsilon + sqrt(v_t)`
125
+ v_add_sqrt_v = torch.add(sqrt_vt, epsilon_broad)
126
+
127
+ return v_add_sqrt_v
128
+
129
+
130
def _output_var_t_compute_use_nesterov(varparams):
    """Return the updated variable for the Nesterov variant of Adam.

    `formula; var_t = var - lr_t * (m_t * beta1 + (1 - beta1) * grad) / (epsilon + sqrt(v_t))`

    Args:
        varparams: a ``VarParams`` tuple carrying ``var``, ``lr_t``, ``m_t``,
            ``beta1_broad``, ``grad``, ``epsilon`` and ``v_t``.

    Returns:
        The updated variable tensor ``var_t``.
    """
    var = varparams.var
    lr_t = varparams.lr_t
    m_t = varparams.m_t
    beta1_broad = varparams.beta1_broad
    grad = varparams.grad
    epsilon = varparams.epsilon
    v_t = varparams.v_t

    input_dtype = var.dtype
    # The constants are shape-(1,) tensors, which PyTorch does not treat as
    # scalars. They must live on the same device as beta1_broad, otherwise the
    # multiply below fails for non-CPU inputs (same handling as _output_m_compute).
    s_one = torch.ones((1), dtype=input_dtype, device=beta1_broad.device)
    s_neg_one = torch.ones((1), dtype=input_dtype, device=beta1_broad.device) * -1

    # `formula; m_t * beta1`
    v_muls_mt_beta1 = torch.mul(m_t, beta1_broad)

    # `formula; 1 - beta1`
    vsub_1_beta1 = torch.add(torch.mul(beta1_broad, s_neg_one), s_one)

    # `formula; (1 - beta1) * grad`
    v_mul_grad = torch.mul(vsub_1_beta1, grad)

    # `formula; m_t * beta1 + (1 - beta1) * grad`
    v_div_left = torch.add(v_muls_mt_beta1, v_mul_grad)

    # `formula; lr_t * (m_t * beta1 + (1 - beta1) * grad)`
    # lr_t is moved next to m_t, as the non-Nesterov path does, then broadcast.
    lrt_broad = lr_t.to(m_t.device).expand_as(var)
    v_mul_left = torch.mul(lrt_broad, v_div_left)

    # `formula; epsilon + sqrt(v_t)`
    v_add_sqrt_v = _inner_eps_add_sqrt_vt_compute(epsilon, v_t)

    # `formula; var - lr_t * (m_t * beta1 + (1 - beta1) * grad) / (epsilon + sqrt(v_t))`
    var_t = torch.sub(var, torch.div(v_mul_left, v_add_sqrt_v))
    return var_t
179
+
180
+
181
def _output_var_t_compute(var, lr_t, m_t, epsilon, v_t):
    """Return the updated variable for plain (non-Nesterov) Adam.

    `formula; var_t = var - lr_t * m_t / (epsilon + sqrt(v_t))`

    ``lr_t`` is moved to ``m_t``'s device before it is used.
    """
    step = torch.mul(lr_t.to(m_t.device), m_t)              # lr_t * m_t
    denominator = _inner_eps_add_sqrt_vt_compute(epsilon, v_t)  # epsilon + sqrt(v_t)
    return torch.sub(var, torch.div(step, denominator))
200
+
201
+
202
def npu_apply_adam(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, use_locking, use_nesterov, out):
    """Reference computation of one Adam step (presumably mirrors ``torch_npu.npu_apply_adam``).

    Args:
        beta1_power, beta2_power: beta1/beta2 raised to the step count.
        lr: learning rate.
        beta1, beta2: moment decay rates.
        epsilon: numerical-stability term added to ``sqrt(v_t)``.
        grad: gradient tensor.
        use_locking: accepted for signature compatibility; not used here.
        use_nesterov: select the Nesterov variant of the variable update.
        out: ``(var, m, v)`` tuple holding the current state.

    Returns:
        ``(var_t, m_t, v_t)``: the updated variable and moments.
    """
    var, m, v = out
    beta1_broad = torch.tensor(beta1, dtype=m.dtype).to(m.device).expand_as(m)

    m_t = _output_m_compute(m, beta1_broad, grad)
    v_t = _output_v_compute(v, beta2, grad)
    lr_t = _inner_lr_compute(lr, beta2_power, beta1_power, grad)

    if not use_nesterov:
        return _output_var_t_compute(var, lr_t, m_t, epsilon, v_t), m_t, v_t

    params = VarParams(var, lr_t, m_t, beta1_broad, grad, epsilon, v_t)
    return _output_var_t_compute_use_nesterov(params), m_t, v_t
@@ -0,0 +1,27 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+
18
+
19
def npu_group_norm_silu(x, gama, beta, group, eps):
    """Group normalization followed by SiLU on the normalized output.

    Args:
        x: 4-D input tensor laid out as (N, C, H, W).
        gama: per-channel scale passed to ``native_group_norm``.
        beta: per-channel shift passed to ``native_group_norm``.
        group: number of groups.
        eps: variance epsilon.

    Returns:
        A list ``[silu(normalized), mean, rstd]`` built from the outputs of
        ``torch.ops.aten.native_group_norm``, with SiLU applied to the first entry.

    Raises:
        ValueError: if ``x`` is not 4-D or the op returns nothing.
    """
    if x.dim() != 4:
        raise ValueError("x shape should be (N, C, H, W)")
    n, c, h, w = x.shape
    outputs = list(torch.ops.aten.native_group_norm(x, gama, beta, n, c, h * w, group, eps))
    if not outputs:
        raise ValueError("run native_group_norm failed")
    outputs[0] = torch.nn.functional.silu(outputs[0])
    return outputs
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,7 +13,9 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from msprobe.pytorch.parse_tool import cli
16
+ import torch
17
17
 
18
- if __name__ == '__main__':
19
- cli.parse()
18
+
19
def npu_mish(x):
    """Apply Mish, ``x * tanh(softplus(x))``, element-wise (non in-place)."""
    return torch.nn.functional.mish(x)
@@ -0,0 +1,50 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import numpy as np
18
+
19
+
20
def softmax_func(x, axis=None):
    """Numerically stable softmax computed in float32.

    Args:
        x: input tensor; it is converted to float32 before any arithmetic.
        axis: dimension to normalize over (callers in this file pass -1).

    Returns:
        A float32 tensor with the same shape as ``x``. If any normalizing sum is
        zero, the whole result is a zero tensor. The original code returned the
        int ``0`` here, which broke callers that call tensor methods such as
        ``.to(dtype)``.
    """
    x = x.float()
    # Subtract the max along `axis` so exp() cannot overflow.
    x_max = x.max(dim=axis, keepdims=True).values
    y = torch.exp(x - x_max)
    x_sum = y.sum(dim=axis, keepdims=True)
    if (x_sum == 0).any():
        return torch.zeros_like(y)
    return y / x_sum
28
+
29
+
30
def npu_moe_gating_top_k_softmax(x, finished_optional, k):
    """Softmax over experts followed by top-k selection.

    Args:
        x: 2-D gating logits; the last dimension indexes experts.
        finished_optional: optional boolean tensor with one entry per row of
            ``x``. Rows flagged True get every selected expert index replaced
            by ``num_expert``.
        k: number of experts to keep per row.

    Returns:
        ``(y, expert_idx, row_idx)``:
        - ``y``: top-k softmax probabilities, in ``x``'s dtype.
        - ``expert_idx``: indices of the selected experts, sorted by
          descending probability (stable on ties).
        - ``row_idx``: int64 tensor of ``y``'s shape with
          ``row_idx[i][j] == j * num_rows + i``.

    Raises:
        ValueError: if ``x`` has fewer than 2 dimensions, or
            ``finished_optional`` is 0-dimensional.
    """
    # Validate rank before slicing. The `[:, :k]` below needs a 2nd dimension,
    # and previously 1-D input failed there with an IndexError instead of
    # reaching the intended ValueError.
    if x.dim() < 2:
        raise ValueError("Input x must have at least 2 dimensions.")
    input_dtype = x.dtype
    num_expert = x.shape[-1]

    softmax = softmax_func(x, -1).to(input_dtype)
    expert_idx = torch.argsort(-softmax, dim=-1, stable=True)[:, :k]
    y = torch.gather(softmax, index=expert_idx, dim=-1)

    if finished_optional is not None:
        if finished_optional.dim() < 1:
            raise ValueError("Finished_optional must have at least 1 dimensions.")
        finished = finished_optional.view(finished_optional.shape[0], 1).expand(-1, k)
        expert_idx = torch.where(finished, num_expert, expert_idx)

    # Column-major enumeration of the (rows, k) output positions.
    row_idx = torch.arange(y.shape[0] * y.shape[1]).reshape(y.shape[1], y.shape[0]).t()
    return y, expert_idx, row_idx
@@ -0,0 +1,21 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+
18
+
19
def npu_sort_v2(x, dim=-1, descending=False, out=None):
    """Return ``x`` sorted along ``dim``; indices are discarded and ``out`` is ignored."""
    return torch.sort(x, dim=dim, descending=descending).values
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,6 +18,7 @@ import os
18
18
  import pickle
19
19
  import random
20
20
  import stat
21
+ import inspect
21
22
  from functools import wraps
22
23
 
23
24
  import numpy as np
@@ -27,7 +28,7 @@ from msprobe.core.common.exceptions import DistributedNotInitializedError
27
28
  from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
28
29
  check_file_or_directory_path, check_path_before_create, FileOpen)
29
30
  from msprobe.core.common.log import logger
30
- from msprobe.core.common.utils import check_seed_all
31
+ from msprobe.core.common.utils import check_seed_all, is_save_variable_valid
31
32
  from packaging import version
32
33
 
33
34
  try:
@@ -56,7 +57,7 @@ def parameter_adapter(func):
56
57
 
57
58
  @wraps(func)
58
59
  def inner(self, *args, **kwargs):
59
- if self.op_name_ == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor):
60
+ if self.api_name == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor):
60
61
  input_tensor = args[0]
61
62
  indices = args[1]
62
63
  if indices.dtype == torch.uint8:
@@ -76,7 +77,7 @@ def parameter_adapter(func):
76
77
  else:
77
78
  res = [input_tensor[tensor_index] for tensor_index in indices]
78
79
  return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0)
79
- if self.op_name_ == "__eq__" and len(args) > 1 and args[1] is None:
80
+ if self.api_name == "__eq__" and len(args) > 1 and args[1] is None:
80
81
  return False
81
82
  return func(self, *args, **kwargs)
82
83
 
@@ -260,6 +261,10 @@ class Const:
260
261
  NPU = 'NPU'
261
262
  DISTRIBUTED = 'Distributed'
262
263
 
264
+ HIFLOAT8_TYPE = "torch_npu.HiFloat8Tensor"
265
+ FLOAT8_E5M2_TYPE = "torch.float8_e5m2"
266
+ FLOAT8_E4M3FN_TYPE = "torch.float8_e4m3fn"
267
+
263
268
  RAISE_PRECISION = {
264
269
  torch.float16: torch.float32,
265
270
  torch.bfloat16: torch.float32,
@@ -402,3 +407,91 @@ def load_api_data(api_data_bytes):
402
407
  except Exception as e:
403
408
  raise RuntimeError(f"load api_data from bytes failed") from e
404
409
  return buffer
410
+
411
+
412
def is_recomputation():
    """Check whether the current operation runs in the re-computation phase.

    The current call stack is inspected against known patterns. Megatron and
    MindSpeed are currently supported:
    - megatron: a 'backward' frame is called from 'torch/autograd/function.py'.
    - mindspeed: a 'checkpoint_function_backward' frame is called from
      'torch/autograd/function.py', or a custom module (using
      CheckpointWithoutOutput) runs a 'recompute_fn' frame inside
      'torch/utils/checkpoint.py'.

    Returns:
        bool: True if in the re-computation phase, False otherwise (including
        when the stack cannot be captured).
    """
    try:
        # context=0: only function and file names are needed, so skip reading
        # source lines for every frame (faster, and fewer ways to fail).
        call_stack = inspect.stack(0)
    except Exception as e:
        logger.warning(f"Failed to capture stack trace, recomputation validation may be incorrect, error info: {e}.")
        return False

    try:
        # MindSpeed custom module: 'recompute_fn' executing inside torch/utils/checkpoint.py.
        for frame_info in call_stack:
            if frame_info.function == "recompute_fn" and frame_info.filename.endswith('torch/utils/checkpoint.py'):
                return True

        backward_names = (Const.BACKWARD, 'checkpoint_function_backward')
        stack_len = len(call_stack)
        for idx, frame_info in enumerate(call_stack):
            if frame_info.function not in backward_names:
                continue
            # offset 1: Megatron and MindSpeed L0&L1 scenes;
            # offset 2: the latest MindSpeed L2 and ModelLink scenes.
            for offset in (1, 2):
                caller_idx = idx + offset
                if caller_idx < stack_len and \
                        call_stack[caller_idx].filename.endswith('torch/autograd/function.py'):
                    return True
        return False
    finally:
        # Drop the frame references promptly to avoid keeping frames alive.
        del call_stack
456
+
457
+
458
def check_save_param(variable, name, save_backward):
    """Validate the arguments passed to ``PrecisionDebugger.save``.

    For the first invalid argument, a warning is logged and ``ValueError`` is
    raised so the caller can skip the current save.

    Args:
        variable: data to save; checked with ``is_save_variable_valid`` against
            ``torch.Tensor``/``int``/``float``/``str`` (the warning also lists
            ``dict``/``tuple``/``list`` as accepted containers).
        name: label for the saved data; must be a ``str``.
        save_backward: must be a ``bool``.

    Raises:
        ValueError: if any argument is invalid.
    """
    valid_data_types = (torch.Tensor, int, float, str)
    if not is_save_variable_valid(variable, valid_data_types):
        valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list)
        logger.warning("PrecisionDebugger.save variable type not valid, "
                       f"should be one of {valid_data_types_with_nested_types}. "
                       "Skip current save process.")
        raise ValueError("PrecisionDebugger.save variable type not valid.")
    if not isinstance(name, str):
        logger.warning("PrecisionDebugger.save name not valid, "
                       "should be string. "
                       "Skip current save process.")
        raise ValueError("PrecisionDebugger.save name should be string.")
    if not isinstance(save_backward, bool):
        logger.warning("PrecisionDebugger.save_backward not valid, "
                       "should be bool. "
                       "Skip current save process.")
        raise ValueError("PrecisionDebugger.save_backward should be bool.")
477
+
478
+
479
def replace_last_occurrence(text, old, new):
    """Replace only the right-most occurrence of ``old`` in ``text`` with ``new``.

    ``None`` is returned unchanged, and so is text that does not contain ``old``.
    An empty ``old`` matches at the end of ``text``, so ``new`` is appended.
    """
    if text is None:
        return None
    pos = text.rfind(old)
    if pos == -1:
        return text
    return text[:pos] + new + text[pos + len(old):]
486
+
487
+
488
def is_hifloat8_tensor(tensor):
    """Return True if ``tensor`` is a ``torch_npu.HiFloat8Tensor``.

    Always False when ``is_gpu`` is set or the installed ``torch_npu`` has no
    ``HiFloat8Tensor`` attribute.
    """
    if is_gpu or not hasattr(torch_npu, "HiFloat8Tensor"):
        return False
    return isinstance(tensor, torch_npu.HiFloat8Tensor)
492
+
493
+
494
def is_float8_tensor(tensor):
    """Return True for float8 tensors.

    This covers the ``torch.float8_e5m2`` and ``torch.float8_e4m3fn`` dtypes
    (matched by their string names) as well as HiFloat8 tensors.
    """
    float8_dtype_names = (Const.FLOAT8_E5M2_TYPE, Const.FLOAT8_E4M3FN_TYPE)
    return str(tensor.dtype) in float8_dtype_names or is_hifloat8_tensor(tensor)
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,7 +26,7 @@ class DebuggerConfig:
26
26
  self.task = task or common_config.task or Const.STATISTICS
27
27
  self.rank = common_config.rank if common_config.rank else []
28
28
  self.step = common_config.step if common_config.step else []
29
- self.level = level or common_config.level or "L1"
29
+ self.level = level or common_config.level or Const.LEVEL_L1
30
30
  self.enable_dataloader = common_config.enable_dataloader
31
31
  self.scope = task_config.scope if task_config.scope else []
32
32
  self.list = task_config.list if task_config.list else []
@@ -36,10 +36,6 @@ class DebuggerConfig:
36
36
  self.framework = Const.PT_FRAMEWORK
37
37
  self.async_dump = common_config.async_dump if common_config.async_dump else False
38
38
 
39
- if self.level == Const.LEVEL_L2:
40
- self.is_backward_kernel_dump = False
41
- self._check_and_adjust_config_with_l2()
42
-
43
39
  if self.task == Const.FREE_BENCHMARK:
44
40
  self.fuzz_device = task_config.fuzz_device
45
41
  self.handler_type = task_config.handler_type
@@ -65,6 +61,10 @@ class DebuggerConfig:
65
61
 
66
62
  self.check()
67
63
 
64
+ if self.level == Const.LEVEL_L2:
65
+ self.is_backward_kernel_dump = False
66
+ self._check_and_adjust_config_with_l2()
67
+
68
68
  def check_kwargs(self):
69
69
  if self.task and self.task not in Const.TASK_LIST:
70
70
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
@@ -78,6 +78,16 @@ class DebuggerConfig:
78
78
  if not isinstance(self.async_dump, bool):
79
79
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
80
80
  f"The parameters async_dump should be bool.")
81
+ if self.async_dump and self.task == Const.TENSOR and not self.list:
82
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
83
+ f"The parameters async_dump is true in tensor task, the parameters list cannot be "
84
+ f"empty.")
85
+ if self.task == Const.STRUCTURE and self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
86
+ logger.warning_on_rank_0(
87
+ f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
88
+ f"If not, the default level is {Const.LEVEL_MIX}."
89
+ )
90
+ self.level = Const.LEVEL_MIX
81
91
 
82
92
  def check(self):
83
93
  self.check_kwargs()
@@ -93,10 +103,10 @@ class DebuggerConfig:
93
103
  logger.error_on_rank_0(
94
104
  f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' parameter.")
95
105
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
96
-
106
+
97
107
  instance.model = start_model if start_model is not None else instance.model
98
108
  if isinstance(instance.model, torch.nn.Module):
99
- return
109
+ return
100
110
 
101
111
  error_model = None
102
112
  if isinstance(instance.model, (list, tuple)):
@@ -108,7 +118,7 @@ class DebuggerConfig:
108
118
  error_model = instance.model
109
119
 
110
120
  if error_model is not None:
111
- error_info = (f"The 'model' parameter must be a torch.nn.Moudle or list[torch.nn.Moudle] "
121
+ error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
112
122
  f"type, currently there is a {type(error_model)} type.")
113
123
  raise MsprobeException(
114
124
  MsprobeException.INVALID_PARAM_ERROR, error_info)