mindstudio-probe 1.2.2-py3-none-any.whl → 1.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
  3. msprobe/README.md +25 -20
  4. msprobe/core/common/const.py +110 -66
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/utils.py +30 -34
  9. msprobe/core/compare/acc_compare.py +43 -74
  10. msprobe/core/compare/check.py +2 -6
  11. msprobe/core/compare/highlight.py +2 -0
  12. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  13. msprobe/core/compare/merge_result/merge_result.py +8 -2
  14. msprobe/core/compare/multiprocessing_compute.py +19 -12
  15. msprobe/core/compare/npy_compare.py +30 -12
  16. msprobe/core/compare/utils.py +20 -10
  17. msprobe/core/data_dump/api_registry.py +176 -0
  18. msprobe/core/data_dump/data_processor/base.py +2 -2
  19. msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
  20. msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
  21. msprobe/core/data_dump/json_writer.py +38 -35
  22. msprobe/core/grad_probe/constant.py +1 -0
  23. msprobe/core/grad_probe/grad_compare.py +1 -1
  24. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  25. msprobe/docs/01.installation.md +2 -1
  26. msprobe/docs/02.config_introduction.md +17 -15
  27. msprobe/docs/05.data_dump_PyTorch.md +70 -2
  28. msprobe/docs/06.data_dump_MindSpore.md +33 -12
  29. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  30. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  31. msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
  32. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  33. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  34. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  35. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  36. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  37. msprobe/docs/18.online_dispatch.md +1 -1
  38. msprobe/docs/19.monitor.md +124 -62
  39. msprobe/docs/21.visualization_PyTorch.md +32 -13
  40. msprobe/docs/22.visualization_MindSpore.md +32 -13
  41. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  42. msprobe/docs/27.dump_json_instruction.md +278 -8
  43. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  44. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  45. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  46. msprobe/docs/FAQ.md +3 -11
  47. msprobe/docs/img/compare_result.png +0 -0
  48. msprobe/docs/img/merge_result.png +0 -0
  49. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  50. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  51. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  52. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  53. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  54. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  55. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  56. msprobe/mindspore/__init__.py +4 -3
  57. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
  58. msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
  59. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  60. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  61. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  62. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  63. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  64. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  65. msprobe/mindspore/common/const.py +61 -0
  66. msprobe/mindspore/common/utils.py +31 -19
  67. msprobe/mindspore/compare/ms_compare.py +27 -19
  68. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  69. msprobe/mindspore/debugger/debugger_config.py +6 -4
  70. msprobe/mindspore/debugger/precision_debugger.py +22 -10
  71. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  72. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  73. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  74. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  75. msprobe/mindspore/dump/jit_dump.py +14 -9
  76. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  77. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  78. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  79. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  80. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  81. msprobe/mindspore/grad_probe/global_context.py +2 -0
  82. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  83. msprobe/mindspore/grad_probe/hook.py +2 -4
  84. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  85. msprobe/mindspore/monitor/module_hook.py +354 -302
  86. msprobe/mindspore/monitor/utils.py +46 -4
  87. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  88. msprobe/mindspore/service.py +23 -17
  89. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  90. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
  91. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  92. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  93. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  94. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  95. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  96. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  97. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  98. msprobe/pytorch/common/utils.py +29 -7
  99. msprobe/pytorch/debugger/precision_debugger.py +10 -1
  100. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  101. msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
  102. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  103. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  104. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  105. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  106. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  107. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  108. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  109. msprobe/pytorch/function_factory.py +1 -1
  110. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  111. msprobe/pytorch/hook_module/api_register.py +131 -0
  112. msprobe/pytorch/hook_module/hook_module.py +19 -14
  113. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  114. msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
  115. msprobe/pytorch/monitor/csv2tb.py +8 -2
  116. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  117. msprobe/pytorch/monitor/module_hook.py +131 -105
  118. msprobe/pytorch/monitor/module_metric.py +3 -0
  119. msprobe/pytorch/monitor/optimizer_collect.py +55 -4
  120. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  121. msprobe/pytorch/monitor/utils.py +68 -1
  122. msprobe/pytorch/online_dispatch/compare.py +0 -2
  123. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  124. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  125. msprobe/pytorch/online_dispatch/utils.py +3 -0
  126. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  127. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  128. msprobe/pytorch/pt_config.py +11 -7
  129. msprobe/pytorch/service.py +11 -8
  130. msprobe/visualization/builder/graph_builder.py +44 -5
  131. msprobe/visualization/builder/msprobe_adapter.py +0 -1
  132. msprobe/visualization/compare/graph_comparator.py +42 -38
  133. msprobe/visualization/compare/mode_adapter.py +0 -19
  134. msprobe/visualization/graph/base_node.py +8 -1
  135. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  136. msprobe/visualization/graph/graph.py +0 -11
  137. msprobe/visualization/graph/node_op.py +1 -2
  138. msprobe/visualization/graph_service.py +1 -1
  139. msprobe/visualization/utils.py +2 -33
  140. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  141. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  142. msprobe/pytorch/hook_module/api_registry.py +0 -166
  143. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  144. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  145. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  146. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  147. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  148. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  149. msprobe/pytorch/parse.py +0 -19
  150. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  151. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  152. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  153. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/pytorch/common/utils.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,7 +28,7 @@ from msprobe.core.common.exceptions import DistributedNotInitializedError
 from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
                                             check_file_or_directory_path, check_path_before_create, FileOpen)
 from msprobe.core.common.log import logger
-from msprobe.core.common.utils import check_seed_all
+from msprobe.core.common.utils import check_seed_all, is_save_variable_valid
 from packaging import version

 try:
@@ -57,7 +57,7 @@ def parameter_adapter(func):

     @wraps(func)
     def inner(self, *args, **kwargs):
-        if self.op_name_ == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor):
+        if self.api_name == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor):
             input_tensor = args[0]
             indices = args[1]
             if indices.dtype == torch.uint8:
@@ -77,7 +77,7 @@ def parameter_adapter(func):
             else:
                 res = [input_tensor[tensor_index] for tensor_index in indices]
                 return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0)
-        if self.op_name_ == "__eq__" and len(args) > 1 and args[1] is None:
+        if self.api_name == "__eq__" and len(args) > 1 and args[1] is None:
             return False
         return func(self, *args, **kwargs)

@@ -261,6 +261,10 @@ class Const:
     NPU = 'NPU'
     DISTRIBUTED = 'Distributed'

+    HIFLOAT8_TYPE = "torch_npu.HiFloat8Tensor"
+    FLOAT8_E5M2_TYPE = "torch.float8_e5m2"
+    FLOAT8_E4M3FN_TYPE = "torch.float8_e4m3fn"
+
     RAISE_PRECISION = {
         torch.float16: torch.float32,
         torch.bfloat16: torch.float32,
@@ -419,7 +423,11 @@ def is_recomputation():
         bool: True if in the re-computation phase, False otherwise.
     """
     backward_function_indices = []
-    call_stack = inspect.stack()
+    try:
+        call_stack = inspect.stack()
+    except Exception as e:
+        logger.warning(f"Failed to capture stack trace, recomputation validation may be incorrect, error info: {e}.")
+        return False

     # Identify the function 'backward' is being executed within the 'torch/_tensor.py' file.
     for frame_info in call_stack:
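
inspect.stack() walks every live frame and reads source files along the way, so it can raise (for example, during interpreter shutdown or when source is unavailable); the change above downgrades that failure to a warning instead of crashing the hooked training step. A minimal standalone sketch of the same defensive pattern (not msprobe code):

import inspect

def calling_function_names():
    # Treat stack capture as fallible, as is_recomputation() now does:
    # fall back to "no information" rather than propagating the error.
    try:
        call_stack = inspect.stack()
    except Exception:
        return []
    return [frame_info.function for frame_info in call_stack]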
@@ -449,9 +457,11 @@

 def check_save_param(variable, name, save_backward):
     # try catch this api to skip invalid call
-    if not isinstance(variable, (list, dict, torch.Tensor, int, float, str)):
+    valid_data_types = tuple([torch.Tensor, int, float, str])
+    if not is_save_variable_valid(variable, valid_data_types):
+        valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list)
         logger.warning("PrecisionDebugger.save variable type not valid, "
-                       "should be one of list, dict, torch.Tensor, int, float or string. "
+                       f"should be one of {valid_data_types_with_nested_types}"
                        "Skip current save process.")
         raise ValueError
     if not isinstance(name, str):
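
is_save_variable_valid is defined in msprobe/core/common/utils.py, which this diff only summarizes in the file list above. A plausible sketch of its contract, assuming it recursively validates dict/tuple/list containers whose leaves must match the given types (hypothetical reconstruction, not the shipped body):

import torch

def is_save_variable_valid(variable, valid_data_types, depth=0, max_depth=10):
    # Hypothetical reconstruction: leaves must be one of valid_data_types;
    # dict/tuple/list containers are validated recursively with a depth cap.
    if depth > max_depth:
        return False
    if isinstance(variable, valid_data_types):
        return True
    if isinstance(variable, (tuple, list)):
        return all(is_save_variable_valid(item, valid_data_types, depth + 1) for item in variable)
    if isinstance(variable, dict):
        return all(is_save_variable_valid(value, valid_data_types, depth + 1) for value in variable.values())
    return False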
@@ -473,3 +483,15 @@ def replace_last_occurrence(text, old, new):
     if index != -1:
         return text[:index] + text[index:].replace(old, new, 1)
     return text
+
+
+def is_hifloat8_tensor(tensor):
+    if not is_gpu and hasattr(torch_npu, "HiFloat8Tensor") and isinstance(tensor, torch_npu.HiFloat8Tensor):
+        return True
+    return False
+
+
+def is_float8_tensor(tensor):
+    if str(tensor.dtype) in [Const.FLOAT8_E5M2_TYPE, Const.FLOAT8_E4M3FN_TYPE]:
+        return True
+    return is_hifloat8_tensor(tensor)
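
These helpers back the float8 special-casing added elsewhere in this release (see clone_if_tensor and HOOKModule below): many ordinary tensor operations are unsupported for float8 dtypes, so the dump path detects such tensors and skips cloning and backward-hook wiring. A usage sketch, assuming a PyTorch build with float8 dtypes (torch >= 2.1):

import torch
from msprobe.pytorch.common.utils import is_float8_tensor

x = torch.randn(4).to(torch.float8_e5m2)
y = x if is_float8_tensor(x) else x.clone()  # mirror clone_if_tensor: leave float8 tensors as-is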
msprobe/pytorch/debugger/precision_debugger.py
@@ -19,7 +19,7 @@ import torch
 from msprobe.core.common.const import Const, FileCheckConst, MsgConst
 from msprobe.core.common.exceptions import MsprobeException
 from msprobe.core.common.file_utils import FileChecker
-from msprobe.core.common.utils import get_real_step_or_rank
+from msprobe.core.common.utils import get_real_step_or_rank, check_init_step
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import check_save_param
 from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
@@ -172,6 +172,15 @@ class PrecisionDebugger:
             return
         instance.service.save(variable, name, save_backward)

+    @classmethod
+    def set_init_step(cls, step):
+        instance = cls._instance
+        if not instance:
+            raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+        check_init_step(step)
+        instance.service.init_step = step
+        instance.service.loop = 0
+

 def module_dump(module, dump_name):
     if not isinstance(module, torch.nn.Module):
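
The new set_init_step classmethod lets a job resumed from a checkpoint align the dumper's step counter with the training step. A usage sketch; the config path and step value are placeholders:

from msprobe.pytorch import PrecisionDebugger

debugger = PrecisionDebugger(config_path="./config.json")  # placeholder path
debugger.set_init_step(1000)  # dumping now counts steps from 1000, matching the resumed checkpoint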
msprobe/pytorch/dump/module_dump/module_dump.py
@@ -17,7 +17,7 @@ import torch
 from msprobe.core.common.const import Const
 from msprobe.core.data_dump.scope import BaseScope
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.hook_module.api_registry import api_register
+from msprobe.pytorch.hook_module.api_register import get_api_register

 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'

@@ -26,13 +26,14 @@ class ModuleDumper:
     def __init__(self, service):
         self.service = service
         self.hook_handle_list = []
+        self.api_register = get_api_register()

     def start_module_dump(self, module, dump_name):
-        api_register.api_originality()
+        self.api_register.restore_all_api()
         self.register_hook(module, dump_name)

     def stop_module_dump(self):
-        api_register.api_modularity()
+        self.api_register.register_all_api()
         for hook_handle in self.hook_handle_list:
             if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
                 hook_handle.remove()
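
start_module_dump removes the API-level wrappers so the module hooks observe unwrapped calls, and stop_module_dump reinstalls them; the two must always run as a pair. A hypothetical context-manager wrapper (not part of msprobe) that makes the invariant exception-safe:

from contextlib import contextmanager

@contextmanager
def native_api_scope(api_register):
    # Hypothetical helper mirroring start_module_dump/stop_module_dump:
    # run the body with msprobe's API wrappers removed, then re-install
    # them even if the body raises.
    api_register.restore_all_api()
    try:
        yield
    finally:
        api_register.register_all_api()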
msprobe/pytorch/dump/module_dump/module_processer.py
@@ -16,15 +16,17 @@
 from functools import wraps

 import torch
+from torch.utils.hooks import BackwardHook
+
 from msprobe.core.common.const import Const
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import replace_last_occurrence
-from torch.utils.checkpoint import checkpoint as origin_checkpoint
-from torch.utils.checkpoint import set_checkpoint_early_stop
-from torch.utils.hooks import BackwardHook
+from msprobe.pytorch.common.utils import replace_last_occurrence, is_float8_tensor

 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
+if torch_version_above_or_equal_2:
+    from torch.utils.checkpoint import checkpoint as origin_checkpoint, set_checkpoint_early_stop


 def checkpoint_without_early_stop(*args, **kwargs):
@@ -33,7 +35,8 @@ def checkpoint_without_early_stop(*args, **kwargs):


 def replace_checkpoint():
-    torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop
+    if torch_version_above_or_equal_2:
+        torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop


 class ModuleProcesser:
@@ -58,8 +61,9 @@
         return clone_return_value_func

     @staticmethod
+    @recursion_depth_decorator("ModuleDump: ModuleProcesser.clone_if_tensor", max_depth=Const.DUMP_MAX_DEPTH)
     def clone_if_tensor(result):
-        if isinstance(result, torch.Tensor):
+        if isinstance(result, torch.Tensor) and not is_float8_tensor(result):
             return result.clone()
         elif type(result) is tuple:
             return tuple(ModuleProcesser.clone_if_tensor(x) for x in result)
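
recursion_depth_decorator comes from the new msprobe/core/common/decorator.py (+50 lines in this release; its body is not shown in this diff). A minimal sketch of such a guard, assuming it counts live invocations and fails loudly past max_depth so that cyclic containers cannot recurse forever in clone_if_tensor:

import functools

def recursion_depth_decorator(func_info, max_depth=50):
    # Hypothetical reconstruction of the depth guard's behavior.
    def decorator(func):
        depth = 0

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal depth
            depth += 1
            try:
                if depth > max_depth:
                    raise RecursionError(f"{func_info}: recursion depth exceeded {max_depth}")
                return func(*args, **kwargs)
            finally:
                depth -= 1
        return wrapper
    return decorator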
@@ -109,6 +113,8 @@
         for name, module in modules_and_names:
             if module == model:
                 continue
+            if module.__class__.__name__ == "FullyShardedDataParallel":
+                continue
             module_index = (index + Const.SEP) if index != "-1" else ""
             prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
                            name + Const.SEP + module.__class__.__name__ + Const.SEP)
msprobe/pytorch/free_benchmark/common/utils.py
@@ -16,7 +16,7 @@

 import torch
 from msprobe.core.common.exceptions import FreeBenchmarkException
-from msprobe.core.common.utils import recursion_depth_decorator
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark.common.enums import DeviceType


msprobe/pytorch/free_benchmark/compare/single_benchmark.py
@@ -16,7 +16,7 @@
 import math

 import torch
-from msprobe.core.common.utils import recursion_depth_decorator
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
 from msprobe.pytorch.free_benchmark.common.utils import TorchC
msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py
@@ -14,7 +14,7 @@
 # limitations under the License.

 import torch
-from msprobe.core.common.utils import recursion_depth_decorator
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -95,13 +95,13 @@ class AddNoiseLayer(NpuBaseLayer):
         except Exception:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"when calculate maximun value, tensor is changed to float32."
+                f"when calculating the maximum value, the tensor is changed to float32."
             )
             max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
         if max_val < abs_tol:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"Maximun value is less than the minimun threshold. Cancel add noise."
+                f"maximum value is less than the minimum threshold. Cancel adding noise."
             )
             return False
         return True
msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py
@@ -14,7 +14,7 @@
 # limitations under the License.

 import torch
-from msprobe.core.common.utils import recursion_depth_decorator
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -100,13 +100,13 @@ class BitNoiseLayer(NpuBaseLayer):
         except Exception:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"when calculate maximun value, tensor is changed to float32."
+                f"when calculate the maximum value, the tensor is changed to float32."
             )
             max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
         if max_val < abs_tol:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"Maximun value is less than the minimun threshold. Cancel add noise."
+                f"maximum value is less than the minimum threshold. Cancel adding noise."
            )
             return False
         return True
msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py
@@ -14,7 +14,7 @@
 # limitations under the License.

 import torch
-from msprobe.core.common.utils import recursion_depth_decorator
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
 from msprobe.pytorch.free_benchmark.common.params import DataParams
msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py
@@ -15,7 +15,7 @@

 import torch
 from msprobe.core.common.const import Const
-from msprobe.core.common.utils import recursion_depth_decorator
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import CommonField
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
msprobe/pytorch/free_benchmark/result_handlers/check_handler.py
@@ -49,6 +49,6 @@ class CheckerHandler(FuzzHandler):
         except Exception as e:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.params.api_name}, "
-                f"when campare the result exception raise {e}"
+                f"when comparing the results, an exception is raised: {e}"
             )
         return data_params.original_result
msprobe/pytorch/function_factory.py
@@ -70,7 +70,7 @@ class Register(dict):

         def add_register_item(key, value):
             if key in self._dict:
-                logger.warning(f"{value.__name__} has been registered before, so we will overriden it.")
+                logger.warning(f"{value.__name__} has been registered before, so we will override it.")
             self[key] = value
             return value

msprobe/pytorch/grad_probe/grad_monitor.py
@@ -46,7 +46,7 @@ class GradientMonitor:
         if not os.path.exists(self._output_path):
             create_directory(self._output_path)
         else:
-            logger.warning(f"the file in {self._output_path} will be recoverd")
+            logger.warning(f"the file in {self._output_path} will be deleted")
         self._step = -1
         self._param2name = defaultdict(str)

@@ -97,7 +97,7 @@ class GradientMonitor:
             create_directory(output_dirpath)
         output_path = os.path.join(output_dirpath, f"grad_summary_{self._step}.csv")
         if os.path.exists(output_path):
-            logger.warning(f"{output_path} will be recoverd")
+            logger.warning(f"{output_path} will be deleted")
             remove_path(output_path)
         header_result = GradStatCsv.generate_csv_header(self._level_adp, self._bounds)
         output_lines.insert(0, header_result)
msprobe/pytorch/hook_module/api_register.py (new file)
@@ -0,0 +1,131 @@
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import os
+
+import torch
+import torch.distributed as dist
+
+from msprobe.core.common.const import Const
+from msprobe.core.data_dump.api_registry import ApiRegistry
+from msprobe.pytorch.common.utils import (
+    torch_without_guard_version, is_gpu, torch_device_guard, parameter_adapter
+)
+from msprobe.pytorch.function_factory import npu_custom_functions
+from msprobe.pytorch.hook_module.hook_module import HOOKModule
+
+
+torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
+
+_api_types = {
+    Const.PT_FRAMEWORK: {
+        Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
+        Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)),
+        Const.PT_API_TYPE_TORCH: (torch, (torch,)),
+        Const.PT_API_TYPE_VF: (torch._C._VariableFunctionsClass, (torch._VF,)),
+        Const.PT_API_TYPE_DIST: (dist, (dist, dist.distributed_c10d))
+    }
+}
+if not is_gpu:
+    import torch_npu
+    if torch_without_guard_version:
+        _api_types.get(Const.PT_FRAMEWORK).update(
+            {
+                Const.PT_API_TYPE_NPU: (torch.ops.npu, (torch_npu, torch.ops.npu))
+            }
+        )
+    else:
+        _api_types.get(Const.PT_FRAMEWORK).update(
+            {Const.PT_API_TYPE_NPU: (torch_npu._C._VariableFunctionsClass, (torch_npu,))}
+        )
+    _api_types.get(Const.PT_FRAMEWORK).update(
+        {
+            Const.PT_API_TYPE_NPU_DIST: (torch_npu.distributed, (torch_npu.distributed,
+                                                                 torch_npu.distributed.distributed_c10d))
+        }
+    )
+
+_inner_used_api = {}
+_supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
+_cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
+
+
+@parameter_adapter
+def tensor_module_forward(module, *args, **kwargs):
+    return module.api_func(*args, **kwargs)
+
+
+def dist_module_forward(module, *args, **kwargs):
+    handle = module.api_func(*args, **kwargs)
+    if kwargs.get("async_op") or module.api_name in ["isend", "irecv"]:
+        if handle and hasattr(handle, 'wait'):
+            handle.wait()
+        if module.api_name == "batch_isend_irecv":
+            if isinstance(handle, list):
+                for req in handle:
+                    req.wait()
+    return handle
+
+
+def npu_module_forward(module, *args, **kwargs):
+    if not module.need_hook:
+        if module.api_name not in npu_custom_functions:
+            raise Exception(f'There is not bench function {module.api_name}')
+        if module.device == Const.CUDA_LOWERCASE:
+            module.api_name = _cuda_func_mapping.get(module.api_name, module.api_name)
+        if module.device in [Const.CUDA_LOWERCASE, Const.CPU_LOWERCASE]:
+            return npu_custom_functions[module.api_name](*args, **kwargs)
+    return module.api_func(*args, **kwargs)
+
+
+forward_methods = {
+    "Tensor": tensor_module_forward,
+    "Distributed": dist_module_forward,
+    "NPU": npu_module_forward
+}
+
+
+class ApiTemplate(HOOKModule):
+    def __init__(self, api_name, api_func, prefix, hook_build_func, need_hook=True, device=Const.CPU_LOWERCASE):
+        self.api_name = api_name
+        self.api_func = api_func
+        self.prefix = prefix
+        self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP
+        self.need_hook = need_hook
+        self.device = device
+        if self.need_hook:
+            super().__init__(hook_build_func)
+        if prefix == Const.DIST_API_TYPE_PREFIX:
+            self.op_is_distributed = True
+
+    @torch_device_guard
+    def forward(self, *args, **kwargs):
+        exec_func = forward_methods.get(self.prefix)
+        exec_func = functools.partial(exec_func, self) if exec_func else self.api_func
+        return exec_func(*args, **kwargs)
+
+
+api_register = None
+
+
+def get_api_register(return_new=False):
+    if return_new:
+        return ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
+
+    global api_register
+    if api_register is None:
+        api_register = ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
+    return api_register
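
This new module replaces the per-namespace wrap_* files deleted at the end of this diff with one data-driven registry built on the core ApiRegistry. register_all_api/restore_all_api are confirmed by the ModuleDumper hunk above; initialize_hook is assumed from the core interface (msprobe/core/data_dump/api_registry.py, not shown here). A usage sketch:

from msprobe.pytorch.hook_module.api_register import get_api_register

reg = get_api_register()             # process-wide singleton
# reg.initialize_hook(build_hook)    # assumed: install the hook-building callable first
reg.register_all_api()               # patch torch / Tensor / functional / dist / npu APIs
try:
    pass                             # run the instrumented workload here
finally:
    reg.restore_all_api()            # put the original APIs back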
msprobe/pytorch/hook_module/hook_module.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,6 +21,8 @@ import torch
 import torch.nn as nn
 import torch.utils.hooks as full_hooks

+from msprobe.pytorch.common.utils import is_float8_tensor
+
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'


@@ -28,28 +30,27 @@ class HOOKModule(nn.Module):
     module_count = defaultdict(int)
     inner_stop_hook = {}

-    def __init__(self, build_hook) -> None:
+    def __init__(self, hook_build_func) -> None:
         super(HOOKModule, self).__init__()
         self.has_overflow = False
-        self.prefix = ""
         self.current_thread = threading.current_thread().ident
         if self.current_thread not in HOOKModule.inner_stop_hook:
             HOOKModule.inner_stop_hook[self.current_thread] = False
         self.stop_hook = HOOKModule.inner_stop_hook.get(self.current_thread, False)

         if not self.stop_hook:
-            if hasattr(self, "prefix_op_name_"):
-                self.prefix = self.prefix_op_name_
-
             self.forward_data_collected = False
-            forward_pre_hook, forward_hook, backward_hook, _ = build_hook(self.prefix)
-            if torch_version_above_or_equal_2:
-                self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
-                self.register_forward_hook(forward_hook, with_kwargs=True)
-            else:
-                self.register_forward_pre_hook(forward_pre_hook)
-                self.register_forward_hook(forward_hook)
-            self.register_backward_hook(backward_hook)
+
+            prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
+            if callable(hook_build_func):
+                forward_pre_hook, forward_hook, backward_hook, _ = hook_build_func(prefix)
+                if torch_version_above_or_equal_2:
+                    self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
+                    self.register_forward_hook(forward_hook, with_kwargs=True)
+                else:
+                    self.register_forward_pre_hook(forward_pre_hook)
+                    self.register_forward_hook(forward_hook)
+                self.register_backward_hook(backward_hook)

     def __call__(self, *args, **kwargs):
         changed = False
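
HOOKModule now takes the hook-building callable directly (renamed from build_hook) and only wires hooks when the argument is actually callable. A sketch of a callable with the expected shape; the hook signatures follow register_forward_pre_hook/register_forward_hook with with_kwargs=True on torch >= 2.0:

def build_print_hooks(prefix):
    # Returns the 4-tuple HOOKModule unpacks:
    # (forward_pre_hook, forward_hook, backward_hook, extra).
    def forward_pre_hook(module, args, kwargs=None):
        print(f"{prefix}forward: {len(args)} positional args")

    def forward_hook(module, args, kwargs=None, output=None):
        print(f"{prefix}forward done")

    def backward_hook(module, grad_input, grad_output):
        print(f"{prefix}backward")

    return forward_pre_hook, forward_hook, backward_hook, None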
@@ -111,6 +112,10 @@
                     return result
             else:
                 return result
+
+        if is_float8_tensor(var) or not (var.requires_grad and torch.is_grad_enabled()):
+            return result
+
         grad_fn = var.grad_fn
         if grad_fn is not None:
             for hook in non_full_backward_hooks:
msprobe/pytorch/hook_module/register_optimizer_hook.py
@@ -32,8 +32,9 @@ def register_optimizer_hook(data_collector):
     def patch_clip_grad(func):
         def wrapper(*args, **kwargs):
             data_collector.optimizer_status = Const.CLIP_GRAD
-            func(*args, **kwargs)
+            result = func(*args, **kwargs)
             data_collector.optimizer_status = Const.END_PREFIX + Const.CLIP_GRAD
+            return result

         return wrapper

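
The added return matters because the old wrapper swallowed the wrapped function's result: a patched clip_grad_norm_, for instance, would hand its callers None instead of the total norm. A standalone illustration of the pitfall and the fix:

import torch

def patch(func):
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)  # capture the result ...
        return result                   # ... and pass it through, as the fix above does
    return wrapper

clip = patch(torch.nn.utils.clip_grad_norm_)
p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.ones(3)
total_norm = clip([p], max_norm=1.0)
print(total_norm)  # a tensor, not None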