mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
  3. msprobe/README.md +27 -22
  4. msprobe/core/common/const.py +129 -60
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/inplace_ops.yaml +1 -0
  9. msprobe/core/common/utils.py +43 -33
  10. msprobe/core/compare/acc_compare.py +43 -74
  11. msprobe/core/compare/check.py +2 -6
  12. msprobe/core/compare/highlight.py +2 -0
  13. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  14. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  15. msprobe/core/compare/merge_result/merge_result.py +16 -9
  16. msprobe/core/compare/merge_result/utils.py +81 -0
  17. msprobe/core/compare/multiprocessing_compute.py +19 -12
  18. msprobe/core/compare/npy_compare.py +30 -12
  19. msprobe/core/compare/utils.py +30 -10
  20. msprobe/core/data_dump/api_registry.py +176 -0
  21. msprobe/core/data_dump/data_collector.py +58 -13
  22. msprobe/core/data_dump/data_processor/base.py +94 -10
  23. msprobe/core/data_dump/data_processor/factory.py +3 -0
  24. msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
  25. msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
  26. msprobe/core/data_dump/json_writer.py +61 -40
  27. msprobe/core/grad_probe/constant.py +1 -0
  28. msprobe/core/grad_probe/grad_compare.py +1 -1
  29. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  30. msprobe/docs/01.installation.md +27 -1
  31. msprobe/docs/02.config_introduction.md +27 -23
  32. msprobe/docs/03.config_examples.md +24 -0
  33. msprobe/docs/05.data_dump_PyTorch.md +103 -16
  34. msprobe/docs/06.data_dump_MindSpore.md +76 -32
  35. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  36. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  37. msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
  38. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  39. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  40. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  41. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  42. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  43. msprobe/docs/18.online_dispatch.md +1 -1
  44. msprobe/docs/19.monitor.md +332 -273
  45. msprobe/docs/21.visualization_PyTorch.md +42 -13
  46. msprobe/docs/22.visualization_MindSpore.md +43 -13
  47. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  48. msprobe/docs/27.dump_json_instruction.md +301 -27
  49. msprobe/docs/28.debugger_save_instruction.md +94 -0
  50. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  51. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  52. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  53. msprobe/docs/FAQ.md +3 -11
  54. msprobe/docs/img/compare_result.png +0 -0
  55. msprobe/docs/img/merge_result.png +0 -0
  56. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  57. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  58. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  59. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  60. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  61. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  63. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  64. msprobe/mindspore/__init__.py +4 -2
  65. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
  66. msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
  67. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  68. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  69. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  70. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  71. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  72. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  73. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
  74. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  75. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  76. msprobe/mindspore/common/const.py +61 -0
  77. msprobe/mindspore/common/utils.py +48 -18
  78. msprobe/mindspore/compare/ms_compare.py +27 -19
  79. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  80. msprobe/mindspore/debugger/debugger_config.py +31 -6
  81. msprobe/mindspore/debugger/precision_debugger.py +45 -14
  82. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  83. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  86. msprobe/mindspore/dump/jit_dump.py +21 -15
  87. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  88. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  89. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  90. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  91. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  92. msprobe/mindspore/grad_probe/global_context.py +2 -0
  93. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  94. msprobe/mindspore/grad_probe/hook.py +2 -4
  95. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  96. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  97. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  98. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  99. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  100. msprobe/mindspore/monitor/features.py +63 -0
  101. msprobe/mindspore/monitor/module_hook.py +873 -0
  102. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  103. msprobe/mindspore/monitor/utils.py +309 -0
  104. msprobe/mindspore/ms_config.py +8 -2
  105. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  106. msprobe/mindspore/service.py +114 -34
  107. msprobe/pytorch/__init__.py +0 -1
  108. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  109. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
  110. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  111. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  112. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  116. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  117. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  118. msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
  119. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
  120. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  121. msprobe/pytorch/common/utils.py +97 -4
  122. msprobe/pytorch/debugger/debugger_config.py +19 -9
  123. msprobe/pytorch/debugger/precision_debugger.py +24 -1
  124. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  125. msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
  126. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  127. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  132. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  133. msprobe/pytorch/function_factory.py +8 -2
  134. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  135. msprobe/pytorch/hook_module/api_register.py +131 -0
  136. msprobe/pytorch/hook_module/hook_module.py +19 -14
  137. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  138. msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
  139. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  140. msprobe/pytorch/monitor/csv2tb.py +18 -14
  141. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  142. msprobe/pytorch/monitor/module_hook.py +238 -193
  143. msprobe/pytorch/monitor/module_metric.py +9 -6
  144. msprobe/pytorch/monitor/optimizer_collect.py +100 -67
  145. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  146. msprobe/pytorch/monitor/utils.py +76 -44
  147. msprobe/pytorch/online_dispatch/compare.py +0 -2
  148. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  149. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  150. msprobe/pytorch/online_dispatch/utils.py +3 -0
  151. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  152. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  153. msprobe/pytorch/pt_config.py +30 -29
  154. msprobe/pytorch/service.py +114 -32
  155. msprobe/visualization/builder/graph_builder.py +75 -10
  156. msprobe/visualization/builder/msprobe_adapter.py +7 -6
  157. msprobe/visualization/compare/graph_comparator.py +42 -38
  158. msprobe/visualization/compare/mode_adapter.py +0 -19
  159. msprobe/visualization/graph/base_node.py +11 -3
  160. msprobe/visualization/graph/distributed_analyzer.py +71 -3
  161. msprobe/visualization/graph/graph.py +0 -11
  162. msprobe/visualization/graph/node_op.py +4 -3
  163. msprobe/visualization/graph_service.py +4 -5
  164. msprobe/visualization/utils.py +12 -35
  165. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
  166. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  167. msprobe/pytorch/hook_module/api_registry.py +0 -166
  168. msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
  169. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  171. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  172. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  173. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  174. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  175. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  176. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  177. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/pytorch/debugger/precision_debugger.py

@@ -19,8 +19,9 @@ import torch
 from msprobe.core.common.const import Const, FileCheckConst, MsgConst
 from msprobe.core.common.exceptions import MsprobeException
 from msprobe.core.common.file_utils import FileChecker
-from msprobe.core.common.utils import get_real_step_or_rank
+from msprobe.core.common.utils import get_real_step_or_rank, check_init_step
 from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.common.utils import check_save_param
 from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
 from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper
 from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
@@ -158,6 +159,28 @@ class PrecisionDebugger:
             return
         cls._instance.gm.monitor(model)

+    @classmethod
+    def save(cls, variable, name, save_backward=True):
+        instance = cls._instance
+        if not instance:
+            raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+        if instance.task not in [Const.TENSOR, Const.STATISTICS] or instance.config.level != Const.LEVEL_DEBUG:
+            return
+        try:
+            check_save_param(variable, name, save_backward)
+        except ValueError:
+            return
+        instance.service.save(variable, name, save_backward)
+
+    @classmethod
+    def set_init_step(cls, step):
+        instance = cls._instance
+        if not instance:
+            raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+        check_init_step(step)
+        instance.service.init_step = step
+        instance.service.loop = 0
+

 def module_dump(module, dump_name):
     if not isinstance(module, torch.nn.Module):
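
Note: the two classmethods added above give PrecisionDebugger a debug-level save entry point and a way to reset the step counter. A minimal usage sketch (the config.json contents and the tensor name are hypothetical; per the guard in save, the config must use task "statistics" or "tensor" with level "debug"):

```python
import torch
from msprobe.pytorch import PrecisionDebugger

# Hypothetical config: {"task": "statistics", "level": "debug", "dump_path": "./dump"}
debugger = PrecisionDebugger(config_path="./config.json")

# Resume counting from step 100, e.g. after restoring a checkpoint;
# check_init_step validates the value before it is applied.
PrecisionDebugger.set_init_step(100)

x = torch.randn(4, 4, requires_grad=True)
# Dump x under the name "x"; save_backward=True also records its backward data.
PrecisionDebugger.save(x, "x", save_backward=True)
```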
msprobe/pytorch/dump/module_dump/module_dump.py

@@ -17,7 +17,7 @@ import torch
 from msprobe.core.common.const import Const
 from msprobe.core.data_dump.scope import BaseScope
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.hook_module.api_registry import api_register
+from msprobe.pytorch.hook_module.api_register import get_api_register

 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'

@@ -26,13 +26,14 @@ class ModuleDumper:
     def __init__(self, service):
         self.service = service
         self.hook_handle_list = []
+        self.api_register = get_api_register()

     def start_module_dump(self, module, dump_name):
-        api_register.api_originality()
+        self.api_register.restore_all_api()
         self.register_hook(module, dump_name)

     def stop_module_dump(self):
-        api_register.api_modularity()
+        self.api_register.register_all_api()
         for hook_handle in self.hook_handle_list:
             if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
                 hook_handle.remove()
msprobe/pytorch/dump/module_dump/module_processer.py

@@ -16,14 +16,17 @@
 from functools import wraps

 import torch
+from torch.utils.hooks import BackwardHook
+
 from msprobe.core.common.const import Const
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
 from msprobe.pytorch.common.log import logger
-from torch.utils.checkpoint import checkpoint as origin_checkpoint
-from torch.utils.checkpoint import set_checkpoint_early_stop
-from torch.utils.hooks import BackwardHook
+from msprobe.pytorch.common.utils import replace_last_occurrence, is_float8_tensor

 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
+if torch_version_above_or_equal_2:
+    from torch.utils.checkpoint import checkpoint as origin_checkpoint, set_checkpoint_early_stop


 def checkpoint_without_early_stop(*args, **kwargs):
@@ -32,7 +35,8 @@ def checkpoint_without_early_stop(*args, **kwargs):


 def replace_checkpoint():
-    torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop
+    if torch_version_above_or_equal_2:
+        torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop


 class ModuleProcesser:
@@ -45,29 +49,8 @@ class ModuleProcesser:
         self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
         BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook)
         BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
-        BackwardHook.setup_output_hook = ModuleProcesser.filter_tensor_and_tuple(BackwardHook.setup_output_hook)
         replace_checkpoint()

-    @staticmethod
-    def filter_tensor_and_tuple(func):
-        @wraps(func)
-        def wrap_by_filter_tensor_and_tuple(*args, **kwargs):
-            # If setup_output_hook receives non-tensor data, the later dump fails; the workaround is to walk the non-tensor object's attributes and hook its tensor attributes
-            # setup_output_hook is defined as setup_output_hook(self, args), so handle the second positional argument, i.e. args[1]
-            if not isinstance(args[1], (torch.Tensor, tuple)):
-                for item_str in dir(args[1]):
-                    item = getattr(args[1], item_str)
-                    # handle a tensor, or a tuple containing only tensors
-                    if isinstance(item, torch.Tensor) or \
-                            (isinstance(item, tuple) and all(isinstance(x, torch.Tensor) for x in item)):
-                        args_new = (args[0], item)
-                        result = func(*args_new, **kwargs)
-                        setattr(args[1], item_str, result)
-                return args[1]
-            return func(*args, **kwargs)
-
-        return wrap_by_filter_tensor_and_tuple
-
     @staticmethod
     def clone_return_value(func):
         @wraps(func)
@@ -78,14 +61,15 @@ class ModuleProcesser:
         return clone_return_value_func

     @staticmethod
+    @recursion_depth_decorator("ModuleDump: ModuleProcesser.clone_if_tensor", max_depth=Const.DUMP_MAX_DEPTH)
     def clone_if_tensor(result):
-        if isinstance(result, torch.Tensor):
+        if isinstance(result, torch.Tensor) and not is_float8_tensor(result):
             return result.clone()
-        elif isinstance(result, tuple):
+        elif type(result) is tuple:
             return tuple(ModuleProcesser.clone_if_tensor(x) for x in result)
-        elif isinstance(result, list):
+        elif type(result) is list:
             return list(ModuleProcesser.clone_if_tensor(x) for x in result)
-        elif isinstance(result, dict):
+        elif type(result) is dict:
             return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()}
         else:
             return result
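
Note: the switch from isinstance checks to exact type checks in clone_if_tensor matters for tuple subclasses: isinstance(result, tuple) also matches NamedTuples, and rebuilding them with tuple(...) silently drops the subclass. A standalone illustration (not part of the package):

```python
from collections import namedtuple

Point = namedtuple("Point", ["x", "y"])
p = Point(1, 2)

print(isinstance(p, tuple))  # True  -- the old check would recurse into p
print(type(p) is tuple)      # False -- the new check leaves p untouched

# Rebuilding via tuple(...) loses the subclass:
print(tuple(v for v in p))   # (1, 2), a plain tuple, no longer a Point
```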
@@ -103,7 +87,7 @@ class ModuleProcesser:
         return hasattr(module, '_backward_hooks') and \
                len(module._backward_hooks) > 0 and \
                module._is_full_backward_hook is False
-
+
     @staticmethod
     def get_modules_and_names(models):
         modules_and_names_with_index = {}
@@ -129,9 +113,11 @@ class ModuleProcesser:
         for name, module in modules_and_names:
             if module == model:
                 continue
+            if module.__class__.__name__ == "FullyShardedDataParallel":
+                continue
             module_index = (index + Const.SEP) if index != "-1" else ""
-            prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
-                           name + Const.SEP + module.__class__.__name__ + Const.SEP)
+            prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
+                           name + Const.SEP + module.__class__.__name__ + Const.SEP)
             pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = build_hook(
                 BaseScope.Module_Type_Module,
                 prefix_name
@@ -203,9 +189,9 @@ class ModuleProcesser:
         if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
             module.mindstudio_reserved_name = []
         module.mindstudio_reserved_name.append(full_name)
-        forward_full_name = full_name.replace(Const.BACKWARD, Const.FORWARD)
-        ModuleProcesser.module_node[full_name] = ModuleProcesser.module_node[forward_full_name].replace(
-            Const.FORWARD, Const.BACKWARD) if ModuleProcesser.module_node[forward_full_name] else None
+        forward_full_name = replace_last_occurrence(full_name, Const.BACKWARD, Const.FORWARD)
+        ModuleProcesser.module_node[full_name] = replace_last_occurrence(
+            ModuleProcesser.module_node.get(forward_full_name), Const.FORWARD, Const.BACKWARD)
         ModuleProcesser.api_parent_node = None
         if self.scope:
             self.scope.begin_module(full_name)
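
Note: replace_last_occurrence is a new helper imported from msprobe.pytorch.common.utils. Its implementation is not shown in this diff; the following is only a sketch of the presumed semantics (it must tolerate None, since it is applied to module_node.get(...)):

```python
def replace_last_occurrence(text, old, new):
    # Presumed behavior: replace only the last occurrence of `old` in `text`.
    if text is None:
        return None
    index = text.rfind(old)
    if index == -1:
        return text
    return text[:index] + new + text[index + len(old):]

# Unlike str.replace, only the trailing tag is swapped, so a module whose own
# name happens to contain "backward" is left intact:
name = "Module.backward_proj.Linear.backward.0"
print(replace_last_occurrence(name, "backward", "forward"))
# -> "Module.backward_proj.Linear.forward.0"
```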
msprobe/pytorch/free_benchmark/common/utils.py

@@ -16,7 +16,7 @@

 import torch
 from msprobe.core.common.exceptions import FreeBenchmarkException
-from msprobe.core.common.utils import recursion_depth_decorator
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark.common.enums import DeviceType


msprobe/pytorch/free_benchmark/compare/single_benchmark.py

@@ -16,7 +16,7 @@
 import math

 import torch
-from msprobe.core.common.utils import recursion_depth_decorator
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
 from msprobe.pytorch.free_benchmark.common.utils import TorchC
msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py

@@ -14,7 +14,7 @@
 # limitations under the License.

 import torch
-from msprobe.core.common.utils import recursion_depth_decorator
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -95,13 +95,13 @@ class AddNoiseLayer(NpuBaseLayer):
         except Exception:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"when calculate maximun value, tensor is changed to float32."
+                f"when calculating the maximum value, the tensor is changed to float32."
             )
         max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
         if max_val < abs_tol:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"Maximun value is less than the minimun threshold. Cancel add noise."
+                f"maximum value is less than the minimum threshold. Cancel adding noise."
             )
             return False
         return True
msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py

@@ -14,7 +14,7 @@
 # limitations under the License.

 import torch
-from msprobe.core.common.utils import recursion_depth_decorator
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -100,13 +100,13 @@ class BitNoiseLayer(NpuBaseLayer):
         except Exception:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"when calculate maximun value, tensor is changed to float32."
+                f"when calculate the maximum value, the tensor is changed to float32."
             )
         max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
         if max_val < abs_tol:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"Maximun value is less than the minimun threshold. Cancel add noise."
+                f"maximum value is less than the minimum threshold. Cancel adding noise."
             )
             return False
         return True
msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py

@@ -14,7 +14,7 @@
 # limitations under the License.

 import torch
-from msprobe.core.common.utils import recursion_depth_decorator
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
 from msprobe.pytorch.free_benchmark.common.params import DataParams
msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py

@@ -15,7 +15,7 @@

 import torch
 from msprobe.core.common.const import Const
-from msprobe.core.common.utils import recursion_depth_decorator
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import CommonField
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
msprobe/pytorch/free_benchmark/result_handlers/check_handler.py

@@ -49,6 +49,6 @@ class CheckerHandler(FuzzHandler):
         except Exception as e:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.params.api_name}, "
-                f"when campare the result exception raise {e}"
+                f"when comparing the results, an exception is raised: {e}"
             )
             return data_params.original_result
msprobe/pytorch/function_factory.py

@@ -27,6 +27,11 @@ from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotar
 from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \
     npu_scaled_masked_softmax_backward
 from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward
+from msprobe.pytorch.bench_functions.apply_adam import npu_apply_adam
+from msprobe.pytorch.bench_functions.group_norm_silu import npu_group_norm_silu
+from msprobe.pytorch.bench_functions.mish import npu_mish
+from msprobe.pytorch.bench_functions.moe_gating_top_k_softmax import npu_moe_gating_top_k_softmax
+from msprobe.pytorch.bench_functions.sort_v2 import npu_sort_v2
 from msprobe.pytorch.common.utils import logger


@@ -65,7 +70,7 @@ class Register(dict):

     def add_register_item(key, value):
         if key in self._dict:
-            logger.warning(f"{value.__name__} has been registered before, so we will overriden it.")
+            logger.warning(f"{value.__name__} has been registered before, so we will override it.")
         self[key] = value
         return value

@@ -79,7 +84,8 @@
 npu_custom_functions = Register()
 npu_custom_functions([
     npu_apply_adam_w, npu_confusion_transpose, npu_fast_gelu, npu_layer_norm_eval, npu_linear, npu_fusion_attention,
-    npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu, gpu_fusion_attention
+    npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu, gpu_fusion_attention, npu_apply_adam,
+    npu_group_norm_silu, npu_mish, npu_moe_gating_top_k_softmax, npu_sort_v2
 ])

 # register for npu custom backward bench functions
msprobe/pytorch/grad_probe/grad_monitor.py

@@ -46,7 +46,7 @@ class GradientMonitor:
         if not os.path.exists(self._output_path):
             create_directory(self._output_path)
         else:
-            logger.warning(f"the file in {self._output_path} will be recoverd")
+            logger.warning(f"the file in {self._output_path} will be deleted")
         self._step = -1
         self._param2name = defaultdict(str)

@@ -97,7 +97,7 @@ class GradientMonitor:
             create_directory(output_dirpath)
         output_path = os.path.join(output_dirpath, f"grad_summary_{self._step}.csv")
         if os.path.exists(output_path):
-            logger.warning(f"{output_path} will be recoverd")
+            logger.warning(f"{output_path} will be deleted")
             remove_path(output_path)
         header_result = GradStatCsv.generate_csv_header(self._level_adp, self._bounds)
         output_lines.insert(0, header_result)
msprobe/pytorch/hook_module/api_register.py (new file)

@@ -0,0 +1,131 @@
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import os
+
+import torch
+import torch.distributed as dist
+
+from msprobe.core.common.const import Const
+from msprobe.core.data_dump.api_registry import ApiRegistry
+from msprobe.pytorch.common.utils import (
+    torch_without_guard_version, is_gpu, torch_device_guard, parameter_adapter
+)
+from msprobe.pytorch.function_factory import npu_custom_functions
+from msprobe.pytorch.hook_module.hook_module import HOOKModule
+
+
+torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
+
+_api_types = {
+    Const.PT_FRAMEWORK: {
+        Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
+        Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)),
+        Const.PT_API_TYPE_TORCH: (torch, (torch,)),
+        Const.PT_API_TYPE_VF: (torch._C._VariableFunctionsClass, (torch._VF,)),
+        Const.PT_API_TYPE_DIST: (dist, (dist, dist.distributed_c10d))
+    }
+}
+if not is_gpu:
+    import torch_npu
+    if torch_without_guard_version:
+        _api_types.get(Const.PT_FRAMEWORK).update(
+            {
+                Const.PT_API_TYPE_NPU: (torch.ops.npu, (torch_npu, torch.ops.npu))
+            }
+        )
+    else:
+        _api_types.get(Const.PT_FRAMEWORK).update(
+            {Const.PT_API_TYPE_NPU: (torch_npu._C._VariableFunctionsClass, (torch_npu,))}
+        )
+    _api_types.get(Const.PT_FRAMEWORK).update(
+        {
+            Const.PT_API_TYPE_NPU_DIST: (torch_npu.distributed, (torch_npu.distributed,
+                                                                 torch_npu.distributed.distributed_c10d))
+        }
+    )
+
+_inner_used_api = {}
+_supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
+_cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
+
+
+@parameter_adapter
+def tensor_module_forward(module, *args, **kwargs):
+    return module.api_func(*args, **kwargs)
+
+
+def dist_module_forward(module, *args, **kwargs):
+    handle = module.api_func(*args, **kwargs)
+    if kwargs.get("async_op") or module.api_name in ["isend", "irecv"]:
+        if handle and hasattr(handle, 'wait'):
+            handle.wait()
+        if module.api_name == "batch_isend_irecv":
+            if isinstance(handle, list):
+                for req in handle:
+                    req.wait()
+    return handle
+
+
+def npu_module_forward(module, *args, **kwargs):
+    if not module.need_hook:
+        if module.api_name not in npu_custom_functions:
+            raise Exception(f'There is not bench function {module.api_name}')
+        if module.device == Const.CUDA_LOWERCASE:
+            module.api_name = _cuda_func_mapping.get(module.api_name, module.api_name)
+        if module.device in [Const.CUDA_LOWERCASE, Const.CPU_LOWERCASE]:
+            return npu_custom_functions[module.api_name](*args, **kwargs)
+    return module.api_func(*args, **kwargs)
+
+
+forward_methods = {
+    "Tensor": tensor_module_forward,
+    "Distributed": dist_module_forward,
+    "NPU": npu_module_forward
+}
+
+
+class ApiTemplate(HOOKModule):
+    def __init__(self, api_name, api_func, prefix, hook_build_func, need_hook=True, device=Const.CPU_LOWERCASE):
+        self.api_name = api_name
+        self.api_func = api_func
+        self.prefix = prefix
+        self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP
+        self.need_hook = need_hook
+        self.device = device
+        if self.need_hook:
+            super().__init__(hook_build_func)
+        if prefix == Const.DIST_API_TYPE_PREFIX:
+            self.op_is_distributed = True
+
+    @torch_device_guard
+    def forward(self, *args, **kwargs):
+        exec_func = forward_methods.get(self.prefix)
+        exec_func = functools.partial(exec_func, self) if exec_func else self.api_func
+        return exec_func(*args, **kwargs)
+
+
+api_register = None
+
+
+def get_api_register(return_new=False):
+    if return_new:
+        return ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
+
+    global api_register
+    if api_register is None:
+        api_register = ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
+    return api_register
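
Note: get_api_register replaces the old module-level api_register singleton from the deleted api_registry.py (see the module_dump.py hunk above for a call site). A sketch of the intended pattern, using only the methods visible in this diff:

```python
from msprobe.pytorch.hook_module.api_register import get_api_register

api_register = get_api_register()   # lazily created, shared instance
api_register.register_all_api()     # wrap the configured torch/dist/npu APIs with dump hooks
# ... run the model while the hooks are active ...
api_register.restore_all_api()      # put the original functions back

# get_api_register(return_new=True) builds a fresh, independent ApiRegistry
# instead of returning the shared one.
```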
msprobe/pytorch/hook_module/hook_module.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,6 +21,8 @@ import torch
 import torch.nn as nn
 import torch.utils.hooks as full_hooks

+from msprobe.pytorch.common.utils import is_float8_tensor
+
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'


@@ -28,28 +30,27 @@ class HOOKModule(nn.Module):
     module_count = defaultdict(int)
     inner_stop_hook = {}

-    def __init__(self, build_hook) -> None:
+    def __init__(self, hook_build_func) -> None:
         super(HOOKModule, self).__init__()
         self.has_overflow = False
-        self.prefix = ""
         self.current_thread = threading.current_thread().ident
         if self.current_thread not in HOOKModule.inner_stop_hook:
             HOOKModule.inner_stop_hook[self.current_thread] = False
         self.stop_hook = HOOKModule.inner_stop_hook.get(self.current_thread, False)

         if not self.stop_hook:
-            if hasattr(self, "prefix_op_name_"):
-                self.prefix = self.prefix_op_name_
-
             self.forward_data_collected = False
-            forward_pre_hook, forward_hook, backward_hook, _ = build_hook(self.prefix)
-            if torch_version_above_or_equal_2:
-                self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
-                self.register_forward_hook(forward_hook, with_kwargs=True)
-            else:
-                self.register_forward_pre_hook(forward_pre_hook)
-                self.register_forward_hook(forward_hook)
-                self.register_backward_hook(backward_hook)
+
+            prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
+            if callable(hook_build_func):
+                forward_pre_hook, forward_hook, backward_hook, _ = hook_build_func(prefix)
+                if torch_version_above_or_equal_2:
+                    self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
+                    self.register_forward_hook(forward_hook, with_kwargs=True)
+                else:
+                    self.register_forward_pre_hook(forward_pre_hook)
+                    self.register_forward_hook(forward_hook)
+                    self.register_backward_hook(backward_hook)

     def __call__(self, *args, **kwargs):
         changed = False
@@ -111,6 +112,10 @@ class HOOKModule(nn.Module):
                 return result
             else:
                 return result
+
+        if is_float8_tensor(var) or not (var.requires_grad and torch.is_grad_enabled()):
+            return result
+
         grad_fn = var.grad_fn
         if grad_fn is not None:
             for hook in non_full_backward_hooks:
@@ -32,8 +32,9 @@ def register_optimizer_hook(data_collector):
32
32
  def patch_clip_grad(func):
33
33
  def wrapper(*args, **kwargs):
34
34
  data_collector.optimizer_status = Const.CLIP_GRAD
35
- func(*args, **kwargs)
35
+ result = func(*args, **kwargs)
36
36
  data_collector.optimizer_status = Const.END_PREFIX + Const.CLIP_GRAD
37
+ return result
37
38
 
38
39
  return wrapper
39
40
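
Note: the wrapper previously discarded the wrapped function's return value. That matters because torch.nn.utils.clip_grad_norm_ returns the total gradient norm, which training loops commonly log or feed into scaling logic. A standalone check of the fixed pattern (illustrative only, not the package's code):

```python
import torch
from torch.nn.utils import clip_grad_norm_

def patch(func):
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)  # keep the result instead of dropping it
        return result
    return wrapper

model = torch.nn.Linear(2, 2)
model(torch.randn(1, 2)).sum().backward()
total_norm = patch(clip_grad_norm_)(model.parameters(), max_norm=1.0)
print(total_norm)  # a tensor holding the pre-clip norm; the old wrapper returned None
```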