mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
  3. msprobe/README.md +25 -20
  4. msprobe/core/common/const.py +110 -66
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/utils.py +30 -34
  9. msprobe/core/compare/acc_compare.py +43 -74
  10. msprobe/core/compare/check.py +2 -6
  11. msprobe/core/compare/highlight.py +2 -0
  12. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  13. msprobe/core/compare/merge_result/merge_result.py +8 -2
  14. msprobe/core/compare/multiprocessing_compute.py +19 -12
  15. msprobe/core/compare/npy_compare.py +30 -12
  16. msprobe/core/compare/utils.py +20 -10
  17. msprobe/core/data_dump/api_registry.py +176 -0
  18. msprobe/core/data_dump/data_processor/base.py +2 -2
  19. msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
  20. msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
  21. msprobe/core/data_dump/json_writer.py +38 -35
  22. msprobe/core/grad_probe/constant.py +1 -0
  23. msprobe/core/grad_probe/grad_compare.py +1 -1
  24. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  25. msprobe/docs/01.installation.md +2 -1
  26. msprobe/docs/02.config_introduction.md +17 -15
  27. msprobe/docs/05.data_dump_PyTorch.md +70 -2
  28. msprobe/docs/06.data_dump_MindSpore.md +33 -12
  29. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  30. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  31. msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
  32. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  33. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  34. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  35. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  36. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  37. msprobe/docs/18.online_dispatch.md +1 -1
  38. msprobe/docs/19.monitor.md +124 -62
  39. msprobe/docs/21.visualization_PyTorch.md +32 -13
  40. msprobe/docs/22.visualization_MindSpore.md +32 -13
  41. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  42. msprobe/docs/27.dump_json_instruction.md +278 -8
  43. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  44. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  45. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  46. msprobe/docs/FAQ.md +3 -11
  47. msprobe/docs/img/compare_result.png +0 -0
  48. msprobe/docs/img/merge_result.png +0 -0
  49. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  50. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  51. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  52. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  53. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  54. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  55. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  56. msprobe/mindspore/__init__.py +4 -3
  57. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
  58. msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
  59. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  60. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  61. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  62. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  63. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  64. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  65. msprobe/mindspore/common/const.py +61 -0
  66. msprobe/mindspore/common/utils.py +31 -19
  67. msprobe/mindspore/compare/ms_compare.py +27 -19
  68. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  69. msprobe/mindspore/debugger/debugger_config.py +6 -4
  70. msprobe/mindspore/debugger/precision_debugger.py +22 -10
  71. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  72. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  73. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  74. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  75. msprobe/mindspore/dump/jit_dump.py +14 -9
  76. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  77. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  78. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  79. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  80. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  81. msprobe/mindspore/grad_probe/global_context.py +2 -0
  82. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  83. msprobe/mindspore/grad_probe/hook.py +2 -4
  84. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  85. msprobe/mindspore/monitor/module_hook.py +354 -302
  86. msprobe/mindspore/monitor/utils.py +46 -4
  87. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  88. msprobe/mindspore/service.py +23 -17
  89. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  90. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
  91. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  92. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  93. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  94. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  95. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  96. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  97. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  98. msprobe/pytorch/common/utils.py +29 -7
  99. msprobe/pytorch/debugger/precision_debugger.py +10 -1
  100. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  101. msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
  102. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  103. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  104. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  105. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  106. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  107. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  108. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  109. msprobe/pytorch/function_factory.py +1 -1
  110. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  111. msprobe/pytorch/hook_module/api_register.py +131 -0
  112. msprobe/pytorch/hook_module/hook_module.py +19 -14
  113. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  114. msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
  115. msprobe/pytorch/monitor/csv2tb.py +8 -2
  116. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  117. msprobe/pytorch/monitor/module_hook.py +131 -105
  118. msprobe/pytorch/monitor/module_metric.py +3 -0
  119. msprobe/pytorch/monitor/optimizer_collect.py +55 -4
  120. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  121. msprobe/pytorch/monitor/utils.py +68 -1
  122. msprobe/pytorch/online_dispatch/compare.py +0 -2
  123. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  124. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  125. msprobe/pytorch/online_dispatch/utils.py +3 -0
  126. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  127. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  128. msprobe/pytorch/pt_config.py +11 -7
  129. msprobe/pytorch/service.py +11 -8
  130. msprobe/visualization/builder/graph_builder.py +44 -5
  131. msprobe/visualization/builder/msprobe_adapter.py +0 -1
  132. msprobe/visualization/compare/graph_comparator.py +42 -38
  133. msprobe/visualization/compare/mode_adapter.py +0 -19
  134. msprobe/visualization/graph/base_node.py +8 -1
  135. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  136. msprobe/visualization/graph/graph.py +0 -11
  137. msprobe/visualization/graph/node_op.py +1 -2
  138. msprobe/visualization/graph_service.py +1 -1
  139. msprobe/visualization/utils.py +2 -33
  140. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  141. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  142. msprobe/pytorch/hook_module/api_registry.py +0 -166
  143. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  144. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  145. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  146. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  147. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  148. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  149. msprobe/pytorch/parse.py +0 -19
  150. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  151. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  152. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  153. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,176 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Dict, Any, Optional, Callable, Union, List, Tuple
17
+
18
+ from msprobe.core.common.const import Const
19
+ from msprobe.core.common.file_utils import load_yaml
20
+
21
+
22
+ def _get_attr(module, attr_name):
23
+ if Const.SEP in attr_name:
24
+ sub_module_name, sub_attr = attr_name.rsplit(Const.SEP, 1)
25
+ sub_module = getattr(module, sub_module_name, None)
26
+ attr = getattr(sub_module, sub_attr, None)
27
+ else:
28
+ attr = getattr(module, attr_name, None)
29
+ return attr
30
+
31
+
32
+ class ApiWrapper:
33
+ def __init__(
34
+ self, api_types: Dict[str, Dict[str, Any]],
35
+ api_list_paths: Union[str, List[str], Tuple[str]]
36
+ ):
37
+ self.api_types = api_types
38
+ if not isinstance(api_list_paths, (list, tuple)):
39
+ api_list_paths = [api_list_paths] * len(self.api_types)
40
+ elif len(api_list_paths) != len(self.api_types):
41
+ raise RuntimeError("The number of api_list_paths must be equal to the number of frameworks in 'api_types', "
42
+ "when api_list_paths is a list or tuple.")
43
+ self.api_list_paths = api_list_paths
44
+ self.api_names = self._get_api_names()
45
+ self.wrapped_api_functions = dict()
46
+
47
+ def wrap_api(
48
+ self, api_templates, hook_build_func: Optional[Callable]
49
+ ):
50
+ api_types_num = sum([len(v) for v in self.api_types.values()])
51
+ if not isinstance(api_templates, (list, tuple)):
52
+ api_templates = [api_templates] * api_types_num
53
+ elif len(api_templates) != api_types_num:
54
+ raise RuntimeError("The number of api_templates must be equal to the number of api_types, "
55
+ "when api_templates is a list or tuple.")
56
+
57
+ self.wrapped_api_functions.clear()
58
+ index = 0
59
+ for framework, api_types in self.api_types.items():
60
+ wrapped_functions_in_framework = dict()
61
+ for api_type, api_modules in api_types.items():
62
+ wrapped_functions = dict()
63
+ name_prefix = Const.API_DATA_PREFIX.get(framework, {}).get(api_type, "API")
64
+ api_template = api_templates[index]
65
+ index += 1
66
+ for api_name in self.api_names.get(framework, {}).get(api_type, []):
67
+ ori_api = _get_attr(api_modules[0], api_name)
68
+ if callable(ori_api):
69
+ def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template):
70
+ def api_function(*args, **kwargs):
71
+ return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs)
72
+ api_function.__name__ = api_name
73
+ return api_function
74
+ wrapped_functions[api_name] = wrap_api_func(api_name, ori_api, name_prefix,
75
+ hook_build_func, api_template)
76
+ wrapped_functions_in_framework[api_type] = wrapped_functions
77
+ self.wrapped_api_functions[framework] = wrapped_functions_in_framework
78
+ return self.wrapped_api_functions
79
+
80
+ def _get_api_names(self):
81
+ api_names = dict()
82
+
83
+ for index, framework in enumerate(self.api_types.keys()):
84
+ api_list = load_yaml(self.api_list_paths[index])
85
+ valid_names = dict()
86
+ for api_type, api_modules in self.api_types.get(framework, {}).items():
87
+ api_from_file = api_list.get(Const.SUPPORT_API_DICT_KEY_MAP.get(framework, {}).get(api_type), [])
88
+ names = set()
89
+ for api_name in api_from_file:
90
+ target_attr = api_name
91
+ target_module = api_modules[0]
92
+ if Const.SEP in api_name:
93
+ sub_module_name, target_attr = api_name.rsplit(Const.SEP, 1)
94
+ target_module = getattr(api_modules[0], sub_module_name, None)
95
+ if target_module and target_attr in dir(target_module):
96
+ names.add(api_name)
97
+ valid_names[api_type] = names
98
+ api_names[framework] = valid_names
99
+
100
+ return api_names
101
+
102
+
103
+ class ApiRegistry:
104
+ """
105
+ Base class for api registry.
106
+ """
107
+
108
+ def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates):
109
+ self.ori_api_attr = dict()
110
+ self.wrapped_api_attr = dict()
111
+ self.inner_used_ori_attr = dict()
112
+ self.inner_used_wrapped_attr = dict()
113
+ self.api_types = api_types
114
+ self.inner_used_api = inner_used_api
115
+ self.supported_api_list_path = supported_api_list_path
116
+ self.api_templates = api_templates
117
+
118
+ @staticmethod
119
+ def store_ori_attr(ori_api_group, api_list, api_ori_attr):
120
+ for api in api_list:
121
+ api_ori_attr[api] = _get_attr(ori_api_group, api)
122
+
123
+ @staticmethod
124
+ def set_api_attr(api_group, attr_dict):
125
+ for api, api_attr in attr_dict.items():
126
+ if Const.SEP in api:
127
+ sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
128
+ sub_module = getattr(api_group, sub_module_name, None)
129
+ if sub_module is not None:
130
+ setattr(sub_module, sub_op, api_attr)
131
+ else:
132
+ setattr(api_group, api, api_attr)
133
+
134
+ def register_all_api(self):
135
+ for framework, api_types in self.api_types.items():
136
+ for api_type, api_modules in api_types.items():
137
+ api_type_with_framework = framework + Const.SEP + api_type
138
+ for module in api_modules[1]:
139
+ self.set_api_attr(module, self.wrapped_api_attr.get(api_type_with_framework, {}))
140
+
141
+ def register_inner_used_api(self):
142
+ for api_type in self.inner_used_api.keys():
143
+ self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_wrapped_attr.get(api_type, {}))
144
+
145
+ def restore_all_api(self):
146
+ for framework, api_types in self.api_types.items():
147
+ for api_type, api_modules in api_types.items():
148
+ api_type_with_framework = framework + Const.SEP + api_type
149
+ for module in api_modules[1]:
150
+ self.set_api_attr(module, self.ori_api_attr.get(api_type_with_framework, {}))
151
+
152
+ def restore_inner_used_api(self):
153
+ for api_type in self.inner_used_api.keys():
154
+ self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_ori_attr.get(api_type, {}))
155
+
156
+ def initialize_hook(self, hook_build_func):
157
+ api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path)
158
+ wrapped_api_functions = api_wrapper.wrap_api(self.api_templates, hook_build_func)
159
+
160
+ for framework, api_types in self.api_types.items():
161
+ for api_type, api_modules in api_types.items():
162
+ ori_attr = dict()
163
+ self.store_ori_attr(api_modules[0], api_wrapper.api_names.get(framework).get(api_type), ori_attr)
164
+ api_type_with_framework = framework + Const.SEP + api_type
165
+ self.ori_api_attr[api_type_with_framework] = ori_attr
166
+ self.wrapped_api_attr[api_type_with_framework] = wrapped_api_functions.get(framework).get(api_type)
167
+
168
+ for inner_used_api_type, inner_used_api_list in self.inner_used_api.items():
169
+ ori_attr = dict()
170
+ wrapped_attr = dict()
171
+ for api_name in inner_used_api_list[1:]:
172
+ if self.ori_api_attr.get(inner_used_api_type, {}).get(api_name):
173
+ ori_attr[api_name] = self.ori_api_attr.get(inner_used_api_type).get(api_name)
174
+ wrapped_attr[api_name] = self.wrapped_api_attr.get(inner_used_api_type).get(api_name)
175
+ self.inner_used_ori_attr[inner_used_api_type] = ori_attr
176
+ self.inner_used_wrapped_attr[inner_used_api_type] = wrapped_attr
@@ -252,8 +252,8 @@ class BaseDataProcessor:
252
252
 
253
253
  @classmethod
254
254
  def recursive_apply_transform(cls, args, transform, depth=0) -> Union[dict, list, None]:
255
- if depth > Const.MAX_DEPTH:
256
- logger.error(f"The maximum depth of recursive transform, {Const.MAX_DEPTH} is reached.")
255
+ if depth > Const.DUMP_MAX_DEPTH:
256
+ logger.error(f"The maximum depth of recursive transform, {Const.DUMP_MAX_DEPTH} is reached.")
257
257
  raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
258
258
  if isinstance(args, cls.get_special_types()):
259
259
  arg_transform = transform(args, cls._recursive_key_stack)
@@ -26,7 +26,7 @@ from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, Tenso
26
26
  from msprobe.core.common.file_utils import path_len_exceeds_limit, save_npy
27
27
  from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy
28
28
  from msprobe.mindspore.common.log import logger
29
- from msprobe.mindspore.dump.hook_cell.api_registry import api_register
29
+ from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
30
30
 
31
31
  has_adump = True
32
32
  try:
@@ -44,6 +44,7 @@ class MindsporeDataProcessor(BaseDataProcessor):
44
44
  "dtype": self.analyze_dtype_in_kwargs
45
45
  }
46
46
  self._async_dump_cache = {}
47
+ self.api_register = get_api_register()
47
48
 
48
49
  @staticmethod
49
50
  def get_md5_for_tensor(x):
@@ -74,46 +75,29 @@ class MindsporeDataProcessor(BaseDataProcessor):
74
75
  else:
75
76
  if not ops.is_floating_point(data) or data.dtype == ms.float64:
76
77
  data = data.to(ms.float32)
77
- api_register.norm_inner_op_set_ori_func()
78
- get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
79
- get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min)
80
- get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean)
81
- if hasattr(mint, "norm"):
82
- get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
83
- else:
84
- get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
85
- tensor_stat.max = get_max_value(data).item()
86
- tensor_stat.min = get_min_value(data).item()
87
- tensor_stat.mean = get_mean_value(data).item()
78
+ get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
79
+ tensor_stat.max = mint.max(data).item()
80
+ tensor_stat.min = mint.min(data).item()
81
+ tensor_stat.mean = mint.mean(data).item()
88
82
  tensor_stat.norm = get_norm_value(data).item()
89
- api_register.norm_inner_op_set_hook_func()
90
83
  return tensor_stat
91
84
 
92
85
  @staticmethod
93
86
  def get_stat_info_async(data):
94
87
  tensor_stat = TensorStatInfo()
95
- stack_method = api_register.functional_ori_attr.get("stack", ms.ops.stack)
96
88
  if data.dtype == ms.complex64 or data.dtype == ms.complex128:
97
89
  logger.warning("Async dump do not support complex data!")
98
90
  return tensor_stat
99
91
  elif data.dtype == ms.bool_:
100
- tensor_stat.stack_tensor_stat = (["Max", "Min"], stack_method([data.any(), data.all()]))
92
+ tensor_stat.stack_tensor_stat = (["Max", "Min"], ops.stack([data.any(), data.all()]))
101
93
  elif not data.shape:
102
- tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method([data, data, data, data]))
94
+ tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], ops.stack([data, data, data, data]))
103
95
  else:
104
96
  if not ops.is_floating_point(data) or data.dtype == ms.float64:
105
97
  data = data.to(ms.float32)
106
- api_register.norm_inner_op_set_ori_func()
107
- get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
108
- get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min)
109
- get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean)
110
- if hasattr(mint, "norm"):
111
- get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
112
- else:
113
- get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
114
- tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method(
115
- [get_max_value(data), get_min_value(data), get_mean_value(data), get_norm_value(data)]))
116
- api_register.norm_inner_op_set_hook_func()
98
+ get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
99
+ tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], ops.stack(
100
+ [mint.max(data), mint.min(data), mint.mean(data), get_norm_value(data)]))
117
101
  return tensor_stat
118
102
 
119
103
  @staticmethod
@@ -125,14 +109,17 @@ class MindsporeDataProcessor(BaseDataProcessor):
125
109
  return super().get_special_types() + cls.mindspore_special_type
126
110
 
127
111
  def get_stat_info(self, data):
112
+ self.api_register.restore_inner_used_api()
128
113
  tensor_stat = TensorStatInfo()
129
114
  if data.numel() == 0:
130
- return tensor_stat
115
+ stat_info = tensor_stat
131
116
  else:
132
117
  if self.config.async_dump:
133
- return MindsporeDataProcessor.get_stat_info_async(data)
118
+ stat_info = MindsporeDataProcessor.get_stat_info_async(data)
134
119
  else:
135
- return MindsporeDataProcessor.get_stat_info_sync(data)
120
+ stat_info = MindsporeDataProcessor.get_stat_info_sync(data)
121
+ self.api_register.register_inner_used_api()
122
+ return stat_info
136
123
 
137
124
  def analyze_single_element(self, element, suffix_stack):
138
125
  if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
@@ -191,7 +178,7 @@ class TensorDataProcessor(MindsporeDataProcessor):
191
178
  else:
192
179
  save_tensor_as_npy(tensor, file_path)
193
180
  return single_arg
194
-
181
+
195
182
  def _analyze_numpy(self, ndarray, suffix):
196
183
  dump_data_name, file_path = self.get_save_file_path(suffix)
197
184
  save_npy(ndarray, file_path)
@@ -244,7 +231,7 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
244
231
  api_info_struct = super().analyze_backward(name, module, module_input_output)
245
232
  self.maybe_save_overflow_data()
246
233
  return api_info_struct if self.has_overflow else None
247
-
234
+
248
235
  def analyze_params(self, name, param_name, grad):
249
236
  self.has_overflow = False
250
237
  api_info_struct = super().analyze_params(name, param_name, grad)
@@ -24,14 +24,15 @@ from torch import distributed as dist
24
24
  from torch.distributed.distributed_c10d import _get_default_group
25
25
 
26
26
  from msprobe.core.common.const import Const
27
+ from msprobe.core.common.exceptions import MsprobeException
27
28
  from msprobe.core.common.file_utils import path_len_exceeds_limit
28
29
  from msprobe.core.common.log import logger
29
30
  from msprobe.core.common.utils import convert_tuple
31
+ from msprobe.core.common.decorator import recursion_depth_decorator
30
32
  from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
31
33
  ModuleForwardInputsOutputs, TensorStatInfo
32
- from msprobe.pytorch.common.utils import save_pt, load_pt
34
+ from msprobe.pytorch.common.utils import Const as PtConst, save_pt, is_hifloat8_tensor, is_float8_tensor
33
35
  from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
34
- from msprobe.core.common.utils import recursion_depth_decorator
35
36
 
36
37
  is_gpu = False
37
38
  try:
@@ -78,14 +79,16 @@ class PytorchDataProcessor(BaseDataProcessor):
78
79
  def analyze_device_in_kwargs(element):
79
80
  single_arg = {}
80
81
  single_arg.update({'type': "torch.device"})
81
- if not isinstance(element, str):
82
+ if isinstance(element, (int, str)):
83
+ single_arg.update({"value": element})
84
+ elif isinstance(element, torch.device):
82
85
  if hasattr(element, "index"):
83
86
  device_value = element.type + ":" + str(element.index)
84
87
  else:
85
88
  device_value = element.type
86
89
  single_arg.update({"value": device_value})
87
90
  else:
88
- single_arg.update({"value": element})
91
+ logger.debug(f"Device type {type(element)} is not supported.")
89
92
  return single_arg
90
93
 
91
94
  @staticmethod
@@ -143,7 +146,7 @@ class PytorchDataProcessor(BaseDataProcessor):
143
146
  if data.is_meta:
144
147
  return tensor_stat
145
148
  data_clone = data.detach()
146
- if data_clone.numel() == 0:
149
+ if not data_clone.numel() or not data_clone.data_ptr():
147
150
  return tensor_stat
148
151
  else:
149
152
  if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump:
@@ -214,6 +217,18 @@ class PytorchDataProcessor(BaseDataProcessor):
214
217
  logger.warning(f"Failed to get value of torch.distributed.ReduceOp with error info: {e}.")
215
218
  return {"type": "torch.distributed.ReduceOp", "value": op_type}
216
219
 
220
+ @staticmethod
221
+ def _cast_to_float_if_fp8(tensor):
222
+ dtype = str(tensor.dtype)
223
+ if is_float8_tensor(tensor):
224
+ dtype = PtConst.HIFLOAT8_TYPE if is_hifloat8_tensor(tensor) else dtype
225
+ logger.debug(
226
+ f"The {dtype} tensor analyzing/saving is unsupported in dump function."
227
+ f"Casting to float for processing."
228
+ )
229
+ tensor = tensor.float()
230
+ return tensor, dtype
231
+
217
232
  @classmethod
218
233
  def get_special_types(cls):
219
234
  return super().get_special_types() + cls.pytorch_special_type
@@ -228,7 +243,7 @@ class PytorchDataProcessor(BaseDataProcessor):
228
243
  if isinstance(element, dist.ProcessGroup):
229
244
  return self._analyze_process_group(element)
230
245
  if isinstance(element, dist.P2POp):
231
- return self._analyze_p2pop(element)
246
+ return self._analyze_p2pop(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
232
247
  if isinstance(element, dist.ReduceOp):
233
248
  return self._analyze_reduce_op(element)
234
249
  converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
@@ -247,10 +262,10 @@ class PytorchDataProcessor(BaseDataProcessor):
247
262
  module_input_output.update_output_with_args_and_kwargs()
248
263
  return super().analyze_forward_output(name, module, module_input_output)
249
264
 
250
- def _analyze_p2pop(self, arg):
265
+ def _analyze_p2pop(self, arg, suffix):
251
266
  p2pop_info = {"class_type": "torch.distributed.P2POp"}
252
267
  try:
253
- tensor_info = self._analyze_tensor(arg.tensor, [])
268
+ tensor_info = self._analyze_tensor(arg.tensor, suffix)
254
269
  p2pop_info.update({"tensor": tensor_info})
255
270
  p2pop_info.update({"op": arg.op.__name__})
256
271
  p2pop_info.update({"peer": arg.peer})
@@ -263,10 +278,11 @@ class PytorchDataProcessor(BaseDataProcessor):
263
278
  return p2pop_info
264
279
 
265
280
  def _analyze_tensor(self, tensor, suffix):
281
+ tensor, dtype = self._cast_to_float_if_fp8(tensor)
266
282
  tensor_stat = self.get_stat_info(tensor, self.config.async_dump)
267
283
  tensor_json = {}
268
284
  tensor_json.update({'type': 'torch.Tensor'})
269
- tensor_json.update({'dtype': str(tensor.dtype)})
285
+ tensor_json.update({'dtype': dtype})
270
286
  tensor_json.update({"shape": tensor.shape})
271
287
  if tensor_stat.stack_tensor_stat is None:
272
288
  tensor_json.update({"Max": tensor_stat.max})
@@ -305,13 +321,14 @@ class TensorDataProcessor(PytorchDataProcessor):
305
321
  dump_data_name, file_path = self.get_save_file_path(suffix)
306
322
  single_arg = super()._analyze_tensor(tensor, suffix)
307
323
  single_arg.update({"data_name": dump_data_name})
324
+ tensor, _ = self._cast_to_float_if_fp8(tensor)
308
325
  if self.config.async_dump:
309
326
  self._async_dump_cache[file_path] = tensor.clone().detach()
310
327
  else:
311
328
  saved_tensor = tensor.clone().contiguous().detach()
312
329
  save_pt(saved_tensor, file_path)
313
330
  return single_arg
314
-
331
+
315
332
  def _analyze_numpy(self, ndarray, suffix):
316
333
  dump_data_name, file_path = self.get_save_file_path(suffix)
317
334
  save_pt(torch.tensor(ndarray), file_path)
@@ -383,7 +400,8 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
383
400
  self._analyze_maybe_overflow_flag()
384
401
  if self.has_overflow:
385
402
  for file_path, tensor in self.cached_tensors_and_file_paths.items():
386
- save_pt(tensor, file_path)
403
+ tensor, _ = self._cast_to_float_if_fp8(tensor)
404
+ save_pt(tensor.clone().contiguous().detach(), file_path)
387
405
  self.real_overflow_nums += 1
388
406
  if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums:
389
407
  logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, "
@@ -508,11 +526,13 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
508
526
  return
509
527
 
510
528
  if self.config.is_backward_kernel_dump:
511
- self.forward_args = self.clone_and_detach_tensor(module_input_output.args)
512
- self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs)
513
529
  try:
530
+ self.forward_args = self.clone_and_detach_tensor(module_input_output.args)
531
+ self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs)
514
532
  output = module.forward(*self.forward_args, **self.forward_kwargs)
515
- except Exception:
533
+ except Exception as e:
534
+ if isinstance(e, MsprobeException):
535
+ logger.warning(str(e))
516
536
  self._print_unsupported_log(name)
517
537
  self.enable_kernel_dump = False
518
538
  return
@@ -554,9 +574,17 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
554
574
  self.stop_kernel_dump()
555
575
  logger.info(f"The kernel data of {name} is dumped successfully.")
556
576
 
557
- @recursion_depth_decorator("KernelDump: KernelDumpDataProcessor.clone_and_detach_tensor")
577
+ @recursion_depth_decorator(
578
+ "KernelDump: KernelDumpDataProcessor.clone_and_detach_tensor",
579
+ max_depth=Const.DUMP_MAX_DEPTH
580
+ )
558
581
  def clone_and_detach_tensor(self, input_params):
559
582
  if isinstance(input_params, torch.Tensor):
583
+ if is_float8_tensor(input_params):
584
+ raise MsprobeException(
585
+ MsprobeException.UNSUPPORTED_TYPE_ERROR,
586
+ f"L2 backward dump does not support float8 type."
587
+ )
560
588
  if input_params.requires_grad:
561
589
  return input_params.clone().detach().requires_grad_()
562
590
  return input_params.clone()
@@ -571,6 +599,8 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
571
599
 
572
600
  def analyze_single_element(self, element, suffix_stack):
573
601
  if isinstance(element, torch.Tensor):
602
+ if is_float8_tensor(element):
603
+ return {}
574
604
  if not self.is_found_output_tensor:
575
605
  if element.requires_grad:
576
606
  self.forward_output_tensor = element
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,12 +16,14 @@
16
16
  import csv
17
17
  import os
18
18
  import copy
19
- import numpy as np
19
+ import threading
20
20
 
21
21
  from msprobe.core.common.const import Const, FileCheckConst
22
22
  from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json
23
23
  from msprobe.core.common.log import logger
24
- from msprobe.core.common.exceptions import MsprobeException
24
+ from msprobe.core.common.decorator import recursion_depth_decorator
25
+
26
+ lock = threading.Lock()
25
27
 
26
28
 
27
29
  class DataWriter:
@@ -90,28 +92,32 @@ class DataWriter:
90
92
  self.write_json()
91
93
 
92
94
  def update_data(self, new_data):
93
- if not isinstance(new_data, dict) or len(new_data.keys()) != 1:
94
- logger.warning(f"The data info({new_data}) should be a dict with only one outer key.")
95
- return
96
- dump_data = self.cache_data.get(Const.DATA)
97
- if not isinstance(dump_data, dict):
98
- logger.warning(f"The dump data({dump_data}) should be a dict.")
99
- return
100
-
101
- key = next(iter(new_data.keys()))
102
- if key in dump_data:
103
- dump_data.get(key).update(new_data.get(key))
104
- else:
105
- dump_data.update(new_data)
95
+ with lock:
96
+ if not isinstance(new_data, dict) or len(new_data.keys()) != 1:
97
+ logger.warning(f"The data info({new_data}) should be a dict with only one outer key.")
98
+ return
99
+ dump_data = self.cache_data.get(Const.DATA)
100
+ if not isinstance(dump_data, dict):
101
+ logger.warning(f"The dump data({dump_data}) should be a dict.")
102
+ return
103
+
104
+ key = next(iter(new_data.keys()))
105
+ if key in dump_data:
106
+ dump_data.get(key).update(new_data.get(key))
107
+ else:
108
+ dump_data.update(new_data)
106
109
 
107
110
  def update_stack(self, new_data):
108
- self.cache_stack.update(new_data)
111
+ with lock:
112
+ self.cache_stack.update(new_data)
109
113
 
110
114
  def update_construct(self, new_data):
111
- self.cache_construct.update(new_data)
115
+ with lock:
116
+ self.cache_construct.update(new_data)
112
117
 
113
118
  def update_debug(self, new_data):
114
- self.cache_debug['data'].update(new_data)
119
+ with lock:
120
+ self.cache_debug['data'].update(new_data)
115
121
 
116
122
  def write_data_json(self, file_path):
117
123
  logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
@@ -127,22 +133,21 @@ class DataWriter:
127
133
  save_json(file_path, self.cache_debug, indent=1)
128
134
 
129
135
  def write_json(self):
130
- if self.cache_data:
131
- self.write_data_json(self.dump_file_path)
132
- if self.cache_stack:
133
- self.write_stack_info_json(self.stack_file_path)
134
- if self.cache_construct:
135
- self.write_construct_info_json(self.construct_file_path)
136
- if self.cache_debug:
137
- self.write_debug_info_json(self.debug_file_path)
136
+ with lock:
137
+ if self.cache_data:
138
+ self.write_data_json(self.dump_file_path)
139
+ if self.cache_stack:
140
+ self.write_stack_info_json(self.stack_file_path)
141
+ if self.cache_construct:
142
+ self.write_construct_info_json(self.construct_file_path)
143
+ if self.cache_debug:
144
+ self.write_debug_info_json(self.debug_file_path)
138
145
 
139
146
  def fill_stack_tensor_data(self):
140
147
  self.process_stat_data_recursive(self.cache_data)
141
148
 
142
- def process_stat_data_recursive(self, data, depth=0):
143
- if depth > Const.MAX_DEPTH:
144
- logger.error(f"The maximum depth of recursive process stat data, {Const.MAX_DEPTH} is reached.")
145
- raise MsprobeException(MsprobeException.RECURSION_LIMIT_ERROR)
149
+ @recursion_depth_decorator("AsyncDump: DataWriter.process_stat_data_recursive", max_depth=Const.DUMP_MAX_DEPTH)
150
+ def process_stat_data_recursive(self, data):
146
151
  if isinstance(data, dict):
147
152
  if "tensor_stat" in data.keys():
148
153
  tensor_stat = data["tensor_stat"]
@@ -150,14 +155,12 @@ class DataWriter:
150
155
  logger.warning("Some bad data in async dump")
151
156
  else:
152
157
  tensor_stat_index, tensor_stat_data = tensor_stat[0], tensor_stat[1]
153
- if hasattr(tensor_stat_data, "device") and tensor_stat_data.device != Const.CPU_LOWERCASE:
154
- tensor_stat_data = tensor_stat_data.cpu()
155
158
  for index, stat in zip(tensor_stat_index, tensor_stat_data):
156
159
  data.update({index: stat.item()})
157
160
  del data["tensor_stat"]
158
161
  else:
159
162
  for key in data.keys():
160
- self.process_stat_data_recursive(data[key], depth + 1)
163
+ self.process_stat_data_recursive(data[key])
161
164
  elif isinstance(data, (list, tuple)):
162
165
  for i in data:
163
- self.process_stat_data_recursive(i, depth + 1)
166
+ self.process_stat_data_recursive(i)
@@ -31,6 +31,7 @@ class GradConst:
31
31
  STEP = "step"
32
32
  BOUNDS = "bounds"
33
33
  OUTPUT_PATH = "output_path"
34
+ TIME_STAMP = "time_stamp"
34
35
 
35
36
  # level const
36
37
  LEVEL = "level"
@@ -112,7 +112,7 @@ class GradComparator:
112
112
  result.append([key] + value)
113
113
  result_csv_path = os.path.join(output_dir, "similarities.csv")
114
114
  if os.path.exists(result_csv_path):
115
- logger.warning(f"{result_csv_path} will be recoverd")
115
+ logger.warning(f"{result_csv_path} will be deleted")
116
116
  remove_path(result_csv_path)
117
117
  write_csv(result, result_csv_path)
118
118
 
@@ -20,6 +20,7 @@ import numpy as np
20
20
  from msprobe.core.overflow_check.api_info import APIInfo
21
21
  from msprobe.core.overflow_check.level import OverflowLevel
22
22
  from msprobe.core.overflow_check.utils import has_nan_inf
23
+ from msprobe.core.common.decorator import recursion_depth_decorator
23
24
 
24
25
 
25
26
  class AnomalyScene:
@@ -35,6 +36,7 @@ class AnomalyScene:
35
36
  raise NotImplementedError
36
37
 
37
38
  @staticmethod
39
+ @recursion_depth_decorator("AbnormalScene: AnomalyScene._has_anomaly")
38
40
  def _has_anomaly(data: Union[Dict, Any]) -> bool:
39
41
  """检查张量是否包含异常值"""
40
42
  if isinstance(data, dict):
@@ -16,6 +16,7 @@ pip install mindstudio-probe
16
16
 
17
17
  |版本|发布日期|支持 PyTorch 版本|支持 MindSpore 版本|下载链接|校验码|
18
18
  |:--:|:--:|:--:|:--:|:--:|:--:|
19
+ |1.2.2|2025.3.03|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.2-py3-none-any.whl)|961411bb460d327ea51d6ca4d0c8e8c5565f07c0852d7b8592b781ca35b87212|
19
20
  |1.2.1|2025.2.07|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.1-py3-none-any.whl)|b64b342118558e0339b39237f88a49b93fd24551b0cb202c872fbfef4260c86b|
20
21
  |1.2.0|2025.1.13|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.0-py3-none-any.whl)|1e3aeea1706112f6ee52fd1165037936bb209138f0b9ec42ea21e2c1c8942cdc|
21
22
  |1.1.1|2024.12.09|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.1.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.1/mindstudio_probe-1.1.1-py3-none-any.whl)|577b597555dc155b76ba1a62d575c3546004644e140a456c3ba0824d46283735|
@@ -51,7 +52,7 @@ pip install ./mindstudio_probe*.whl
51
52
 
52
53
  |参数|说明|是否必选|
53
54
  |--|--|:--:|
54
- |--include-mod|指定可选模块,可取值`adump`,表示在编whl包时加入adump模块。默认未配置该参数,表示编基础包。<br>&#8226; adump模块用于MindSpore静态图场景L2级别的dump。<br>&#8226; 仅MindSpore 2.5.0及以上版本支持adump模块。<br>&#8226; 若使用源码安装,编译环境需支持GCC 7或以上版本,和CMAKE 3.14或以上版本。<br>&#8226; 生成的whl包仅限编译时使用的python版本和处理器架构可用。|否|
55
+ |--include-mod|指定可选模块,可取值`adump`,表示在编whl包时加入adump模块。默认未配置该参数,表示编基础包。<br>&#8226; adump模块用于MindSpore静态图场景L2级别的dump。<br>&#8226; 仅MindSpore 2.5.0及以上版本支持adump模块。<br>&#8226; 若使用源码安装,编译环境需支持GCC 7.5或以上版本,和CMAKE 3.14或以上版本。<br>&#8226; 生成的whl包仅限编译时使用的python版本和处理器架构可用。|否|
55
56
 
56
57
  # 特性变更说明
57
58