mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
msprobe/core/data_dump/data_collector.py

```diff
@@ -13,9 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import atexit
 import os
 
-from msprobe.core.data_dump.scope import build_scope, ListScope
+from msprobe.core.data_dump.scope import ScopeFactory
 from msprobe.core.data_dump.json_writer import DataWriter
 from msprobe.core.common.log import logger
 from msprobe.core.common.const import Const
@@ -27,7 +28,6 @@ def build_data_collector(config):
 
 
 class DataCollector:
-    multi_output_apis = ["_sort_", "npu_flash_attention"]
     tasks_need_tensor_data = [Const.OVERFLOW_CHECK, Const.TENSOR, Const.FREE_BENCHMARK]
     level_without_construct = [Const.LEVEL_L1, Const.LEVEL_L2]
 
@@ -37,13 +37,8 @@ class DataCollector:
         self.data_processor = DataProcessorFactory.create_processor(self.config, self.data_writer)
         self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework)
         self.module_count = {}
-        if self.config.task == Const.FREE_BENCHMARK:
-            self.scope = build_scope(ListScope, self.config.scope, self.config.list)
-        else:
-            self.scope = build_scope(None, self.config.scope, self.config.list)
-
-    def __del__(self):
-        self.write_json()
+        self.scope = ScopeFactory(self.config).build_scope()
+        atexit.register(self.write_json)
 
     @property
     def dump_data_dir(self):
@@ -85,6 +80,10 @@ class DataCollector:
         self.data_writer.update_data(data_info)
 
     def pre_forward_data_collect(self, name, module, pid, module_input_output):
+        if self.config.level == Const.LEVEL_L2 and self.check_scope_and_pid(self.scope, name, pid):
+            self.data_processor.analyze_pre_forward(name, module, module_input_output)
+            return
+
         backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
         if self.check_scope_and_pid(self.scope, backward_name, pid):
             self.data_processor.analyze_pre_forward(backward_name, module, module_input_output)
@@ -98,13 +97,14 @@
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return
+        if self.config.level == Const.LEVEL_L2:
+            self.data_processor.analyze_forward(name, module, module_input_output)
+            return
 
         if not self.is_inplace(module):
             data_info = self.data_processor.analyze_forward(name, module, module_input_output)
         else:
             data_info = self.data_processor.analyze_forward_inplace(name, module_input_output)
-        if self.config.level == "L2":
-            return
         self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
 
@@ -114,6 +114,8 @@
             return
 
         data_info = self.data_processor.analyze_backward(name, module, module_input_output)
+        if self.config.level == Const.LEVEL_L2:
+            return
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
 
     def backward_input_data_collect(self, name, module, pid, module_input_output):
```
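Note the switch from `__del__` to `atexit.register(self.write_json)`: a `__del__` finalizer may never run (interpreter shutdown order, surviving reference cycles), while `atexit` callbacks run deterministically at normal interpreter exit, so the final JSON flush is no longer best-effort. A minimal sketch of the pattern, with a hypothetical writer class standing in for `DataCollector`:

```python
import atexit
import json

class Writer:
    def __init__(self, path):
        self.path = path
        self.cache = []
        # Registered callbacks run at normal interpreter exit,
        # unlike __del__, which may be skipped entirely.
        atexit.register(self.flush)

    def flush(self):
        with open(self.path, "w") as f:
            json.dump(self.cache, f)

w = Writer("dump.json")
w.cache.append({"step": 0})
# No explicit flush needed; atexit writes dump.json on exit.
```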
msprobe/core/data_dump/data_processor/base.py

```diff
@@ -1,7 +1,7 @@
 # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -15,10 +15,11 @@
 
 import inspect
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, is_dataclass
 from typing import Tuple, Dict, Optional, Any
 
 import numpy as np
+
 from msprobe.core.common.const import Const
 from msprobe.core.common.log import logger
 from msprobe.core.common.utils import convert_tuple, CompareException
@@ -101,6 +102,8 @@ class BaseDataProcessor:
         self.current_iter = 0
         self._return_forward_new_output = False
         self._forward_new_output = None
+        if hasattr(config, "data_mode"):
+            self.allowed_data_mode = self._get_allowed_data_mode(config.data_mode)
 
     @property
     def data_path(self):
@@ -182,6 +185,18 @@ class BaseDataProcessor:
     def _analyze_numpy(value, numpy_type):
         return {"type": numpy_type, "value": value}
 
+    @staticmethod
+    def _get_allowed_data_mode(data_mode):
+        if Const.ALL in data_mode:
+            allowed_data_mode = [Const.FORWARD, Const.BACKWARD, Const.INPUT, Const.OUTPUT]
+        else:
+            allowed_data_mode = list(set(data_mode))
+            if Const.FORWARD not in allowed_data_mode and Const.BACKWARD not in allowed_data_mode:
+                allowed_data_mode += [Const.FORWARD, Const.BACKWARD]
+            if Const.INPUT not in allowed_data_mode and Const.OUTPUT not in allowed_data_mode:
+                allowed_data_mode += [Const.INPUT, Const.OUTPUT]
+        return allowed_data_mode
+
     @classmethod
     def get_special_types(cls):
         return cls.special_type
```
```diff
@@ -194,25 +209,42 @@
         if isinstance(args, cls.get_special_types()):
             arg_transform = transform(args, cls._recursive_key_stack)
             return arg_transform
+        elif isinstance(args, tuple) and hasattr(args, '_fields'):
+            # namedtuple to dict
+            args_dict = {field: getattr(args, field) for field in args._fields}
+            return cls.apply_transform_dict(args_dict, transform, depth)
+        elif is_dataclass(args):
+            # dataclass to dict
+            args_dict = {field: getattr(args, field) for field in args.__dataclass_fields__}
+            return cls.apply_transform_dict(args_dict, transform, depth)
         elif isinstance(args, (list, tuple)):
-            result_list = []
-            for i, arg in enumerate(args):
-                cls._recursive_key_stack.append(str(i))
-                result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1))
-                cls._recursive_key_stack.pop()
+            result_list = cls.apply_transform_list(args, transform, depth)
             return type(args)(result_list)
         elif isinstance(args, dict):
-            result_dict = {}
-            for k, arg in args.items():
-                cls._recursive_key_stack.append(str(k))
-                result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
-                cls._recursive_key_stack.pop()
-            return result_dict
+            return cls.apply_transform_dict(args, transform, depth)
         elif args is not None:
             logger.warning(f"Data type {type(args)} is not supported.")
             return None
         else:
             return None
+
+    @classmethod
+    def apply_transform_dict(cls, args, transform, depth):
+        result_dict = {}
+        for k, arg in args.items():
+            cls._recursive_key_stack.append(str(k))
+            result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
+            cls._recursive_key_stack.pop()
+        return result_dict
+
+    @classmethod
+    def apply_transform_list(cls, args, transform, depth):
+        result_list = []
+        for i, arg in enumerate(args):
+            cls._recursive_key_stack.append(str(i))
+            result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1))
+            cls._recursive_key_stack.pop()
+        return result_list
 
     def if_return_forward_new_output(self):
         return self._return_forward_new_output
```
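`recursive_apply_transform` now flattens namedtuples and dataclass instances into plain dicts before recursing, so their fields get captured like any other keyed container. A small sketch of just that conversion step in isolation (the `Point` and `Box` types are illustrative, not from msprobe):

```python
from collections import namedtuple
from dataclasses import dataclass, is_dataclass

Point = namedtuple("Point", ["x", "y"])

@dataclass
class Box:
    w: int
    h: int

def to_dict(args):
    # namedtuple: detected as a tuple that carries _fields
    if isinstance(args, tuple) and hasattr(args, "_fields"):
        return {field: getattr(args, field) for field in args._fields}
    # dataclass instance: detected via is_dataclass
    if is_dataclass(args):
        return {field: getattr(args, field) for field in args.__dataclass_fields__}
    return args

print(to_dict(Point(1, 2)))   # {'x': 1, 'y': 2}
print(to_dict(Box(3, 4)))     # {'w': 3, 'h': 4}
```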
```diff
@@ -239,9 +271,7 @@ class BaseDataProcessor:
         Return:
             bool: True if the parameters are in data_mode or data_mode is all, False otherwise.
         """
-        return (Const.ALL in self.config.data_mode or
-                forward_backward in self.config.data_mode or
-                input_output in self.config.data_mode)
+        return forward_backward in self.allowed_data_mode and input_output in self.allowed_data_mode
 
     def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
         pass
```
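The old membership test was a flat OR over `data_mode`, so a match on one axis admitted data regardless of the other. The new code normalizes `data_mode` once into `allowed_data_mode` (filling in whichever axis the user left unspecified) and then requires a match on both the forward/backward axis and the input/output axis. A sketch of the normalization, with assumed stand-ins for the `Const` values:

```python
# Assumed constants mirroring msprobe's Const values.
ALL, FORWARD, BACKWARD, INPUT, OUTPUT = "all", "forward", "backward", "input", "output"

def get_allowed_data_mode(data_mode):
    if ALL in data_mode:
        return [FORWARD, BACKWARD, INPUT, OUTPUT]
    allowed = list(set(data_mode))
    if FORWARD not in allowed and BACKWARD not in allowed:
        allowed += [FORWARD, BACKWARD]   # direction unspecified: allow both
    if INPUT not in allowed and OUTPUT not in allowed:
        allowed += [INPUT, OUTPUT]       # side unspecified: allow both
    return allowed

allowed = get_allowed_data_mode(["input"])
# ["input"] now means "inputs of both forward and backward, but no outputs":
assert FORWARD in allowed and BACKWARD in allowed and OUTPUT not in allowed
```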
msprobe/core/data_dump/data_processor/mindspore_processor.py

```diff
@@ -41,7 +41,7 @@ class MindsporeDataProcessor(BaseDataProcessor):
     @staticmethod
     def get_md5_for_tensor(x):
         x = convert_bf16_to_fp32(x)
-        tensor_bytes = x.asnumpy().tobytes()
+        tensor_bytes = x.contiguous().asnumpy().tobytes()
         crc32_hash = zlib.crc32(tensor_bytes)
         return f"{crc32_hash:08x}"
 
@@ -58,19 +58,19 @@
         if data.numel() == 0:
             return tensor_stat
         elif data.dtype == ms.bool_:
-            data_np = data.asnumpy()
+            data_np = data.contiguous().asnumpy()
             tensor_stat.max = np.max(data_np).item()
             tensor_stat.min = np.min(data_np).item()
         elif not data.shape:
             tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
         elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
-            data_abs = np.abs(data.asnumpy())
+            data_abs = np.abs(data.contiguous().asnumpy())
             tensor_stat.max = np.max(data_abs).item()
             tensor_stat.min = np.min(data_abs).item()
             tensor_stat.mean = np.mean(data_abs).item()
             tensor_stat.norm = np.linalg.norm(data_abs).item()
         else:
-            if not ops.is_floating_point(data):
+            if not ops.is_floating_point(data) or data.dtype == ms.float64:
                 data = data.to(ms.float32)
             api_register.norm_inner_op_set_ori_func()
             get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
```
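Calling `contiguous()` before `asnumpy()` guards against tensors with non-standard strides (e.g. transpose or slicing views), whose raw buffers would otherwise be hashed or reduced in the wrong element order. A hedged MindSpore sketch (assumes a MindSpore build where `Tensor.contiguous()` is available, as the patched code itself does):

```python
import zlib
import mindspore as ms
from mindspore import ops

x = ops.transpose(ms.Tensor([[1.0, 2.0], [3.0, 4.0]]), (1, 0))  # may be a strided view
# contiguous() materializes a standard row-major layout, so tobytes()
# reflects the logical element order rather than the view's strides.
tensor_bytes = x.contiguous().asnumpy().tobytes()
print(f"{zlib.crc32(tensor_bytes):08x}")
```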
msprobe/core/data_dump/data_processor/pytorch_processor.py

```diff
@@ -13,19 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import hashlib
 import zlib
 from dataclasses import asdict
 from typing import List
 
 import numpy as np
 import torch
+from torch import distributed as dist
+
 from msprobe.core.common.const import Const
 from msprobe.core.common.file_utils import path_len_exceeds_limit
 from msprobe.core.common.log import logger
+from msprobe.core.common.utils import convert_tuple
 from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
     ModuleForwardInputsOutputs, TensorStatInfo
 from msprobe.pytorch.common.utils import save_pt, load_pt
 from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
+from msprobe.core.common.utils import recursion_depth_decorator
 
 is_gpu = False
 try:
@@ -35,7 +40,13 @@ except ImportError:
 
 
 class PytorchDataProcessor(BaseDataProcessor):
-    pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor)
+    pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor, torch.memory_format, dist.ProcessGroup)
+    memory_format = {
+        torch.contiguous_format: "contiguous_format",
+        torch.channels_last: "channels_last",
+        torch.channels_last_3d: "channels_last_3d",
+        torch.preserve_format: "preserve_format"
+    }
 
     def __init__(self, config, data_writer):
         super().__init__(config, data_writer)
@@ -79,8 +90,8 @@ class PytorchDataProcessor(BaseDataProcessor):
         if data_clone.numel() == 0:
             return tensor_stat
         elif data_clone.dtype == torch.bool:
-            tensor_stat.max = True in data_clone
-            tensor_stat.min = False not in data_clone
+            tensor_stat.max = torch._C._VariableFunctionsClass.any(data_clone).item()
+            tensor_stat.min = torch._C._VariableFunctionsClass.all(data_clone).item()
         elif not data_clone.shape:
             tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.item()
         elif torch.is_complex(data_clone):
```
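For a boolean tensor, "max" is equivalent to "any element is True" and "min" to "all elements are True"; the patch computes both via the low-level `torch._C._VariableFunctionsClass` bindings plus `.item()`, presumably so the dump tool's own hooked `torch.any`/`torch.all` wrappers are not re-entered. A quick equivalence check using the public API:

```python
import torch

t = torch.tensor([True, False, True])
# max over bools == any(); min over bools == all()
assert torch.any(t).item() == (True in t.tolist())
assert torch.all(t).item() == (False not in t.tolist())
```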
```diff
@@ -104,20 +115,46 @@
         data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
         if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
             return float('nan')
+
         finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
         if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
-            finite_values = data_clone[finite_mask]
+            finite_values = getattr(torch._C._TensorBase, "__getitem__")(data_clone, finite_mask)
             return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
                 torch._C._VariableFunctionsClass.min(finite_values).item()
         else:
-            data_no_nan = data_clone[~data_nan]
+            data_no_nan = getattr(torch._C._TensorBase, "__getitem__")(data_clone, ~data_nan)
             return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
                 torch._C._VariableFunctionsClass.min(data_no_nan).item()
 
+    @staticmethod
+    def process_group_hash(arg):
+        group_ranks = dist.get_process_group_ranks(arg)
+        group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest()
+        return group_ranks_hash
+
     @staticmethod
     def _analyze_torch_size(arg):
         return {"type": "torch.Size", "value": list(arg)}
 
+    @staticmethod
+    def _analyze_memory_format(arg):
+        # Get the memory format
+        format_type = PytorchDataProcessor.memory_format.get(arg)
+
+        return {"type": "torch.memory_format", "format": format_type}
+
+    @staticmethod
+    def _analyze_process_group(arg):
+        group_info = {"type": "torch.ProcessGroup"}
+        try:
+            group_ranks = dist.get_process_group_ranks(arg)
+            group_info.update({"group_ranks": group_ranks})
+            group_id = PytorchDataProcessor.process_group_hash(arg)
+            group_info.update({"group_id": group_id})
+        except Exception as e:
+            logger.warning(f"Failed to get process group(id: {group_id}) ranks info with error info: {e}.")
+        return group_info
+
     @classmethod
     def get_special_types(cls):
         return super().get_special_types() + cls.pytorch_special_type
```
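`_analyze_process_group` records a process group as its member ranks plus a stable identifier derived from them (an MD5 over the stringified rank list), so the same group can be matched across ranks and runs. A hedged single-process sketch of the recorded metadata (the gloo backend and loopback address are illustrative assumptions):

```python
import hashlib
import torch.distributed as dist

# Single-process "world" purely to illustrate the recorded metadata.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)
group = dist.new_group(ranks=[0])
ranks = dist.get_process_group_ranks(group)                      # [0]
group_id = hashlib.md5(str(ranks).encode("utf-8")).hexdigest()   # stable across runs
print({"type": "torch.ProcessGroup", "group_ranks": ranks, "group_id": group_id})
dist.destroy_process_group()
```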
```diff
@@ -127,6 +164,10 @@ class PytorchDataProcessor(BaseDataProcessor):
             return self.torch_object_key[suffix_stack[-1]](element)
         if isinstance(element, torch.Size):
             return self._analyze_torch_size(element)
+        if isinstance(element, torch.memory_format):
+            return self._analyze_memory_format(element)
+        if isinstance(element, dist.ProcessGroup):
+            return self._analyze_process_group(element)
         converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
         if converted_numpy is not element:
             return self._analyze_numpy(converted_numpy, numpy_type)
@@ -320,64 +361,120 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor):
 
 
 class KernelDumpDataProcessor(PytorchDataProcessor):
-    forward_init_status = False
-    multi_output_apis = ["_sort_", "npu_flash_attention"]
-
     def __init__(self, config, data_writer):
         super().__init__(config, data_writer)
+        self.enable_kernel_dump = True
+        self.is_found_output_tensor = False
+        self.is_found_grad_input_tensor = False
+        self.forward_args = None
+        self.forward_kwargs = None
+        self.forward_output_tensor = None
+        self.grad_input_tensor = None
+
+    @staticmethod
+    def start_kernel_dump(config_path):
+        torch_npu.npu.synchronize()
+        torch_npu.npu.init_dump()
+        torch_npu.npu.set_dump(config_path)
+        torch_npu.npu.synchronize()
+
+    @staticmethod
+    def stop_kernel_dump():
+        torch_npu.npu.synchronize()
+        torch_npu.npu.finalize_dump()
+        torch_npu.npu.synchronize()
+
+    @staticmethod
+    def _print_unsupported_log(api_name):
+        logger.warning(f"The kernel dump does not support the {api_name} API.")
+
+    def analyze_pre_forward(self, name, module, module_input_output):
+        if not self.enable_kernel_dump:
+            return
+        if is_gpu:
+            logger.warning("The current environment is not a complete NPU environment, and kernel dump cannot be used.")
+            self.enable_kernel_dump = False
+            return
+
+        if self.config.is_backward_kernel_dump:
+            self.forward_args = self.clone_and_detach_tensor(module_input_output.args)
+            self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs)
+            try:
+                output = module.forward(*self.forward_args, **self.forward_kwargs)
+            except Exception:
+                self._print_unsupported_log(name)
+                self.enable_kernel_dump = False
+                return
+
+            self.analyze_element(convert_tuple(output))
+            if not self.is_found_output_tensor:
+                self._print_unsupported_log(name)
+                self.enable_kernel_dump = False
+            return
+        self.start_kernel_dump(self.config.kernel_config_path)
 
     def analyze_forward(self, name, module, module_input_output):
-        if self.config.is_forward_acl_dump:
-            self.forward_acl_dump(name, module, module_input_output)
+        if not self.enable_kernel_dump:
+            return
+        if self.config.is_backward_kernel_dump:
+            return
+        self.enable_kernel_dump = False
+        self.stop_kernel_dump()
+        logger.info(f"The kernel data of {name} is dumped successfully.")
+
+    def analyze_backward(self, name, module, module_input_output):
+        if not self.enable_kernel_dump:
+            return
+        self.enable_kernel_dump = False
+
+        self.analyze_element(module_input_output.grad_input)
+        if not self.is_found_grad_input_tensor:
+            self._print_unsupported_log(name)
+            return
+        self.start_kernel_dump(self.config.kernel_config_path)
+
+        try:
+            self.forward_output_tensor.backward(self.grad_input_tensor, retain_graph=True)
+        except Exception:
+            self._print_unsupported_log(name)
+            self.stop_kernel_dump()
+            return
+
+        self.stop_kernel_dump()
+        logger.info(f"The kernel data of {name} is dumped successfully.")
+
+    @recursion_depth_decorator("KernelDump: KernelDumpDataProcessor.clone_and_detach_tensor")
+    def clone_and_detach_tensor(self, input_params):
+        if isinstance(input_params, torch.Tensor):
+            if input_params.requires_grad:
+                return input_params.clone().detach().requires_grad_()
+            return input_params.clone()
+        elif isinstance(input_params, tuple):
+            return tuple(self.clone_and_detach_tensor(x) for x in input_params)
+        elif isinstance(input_params, list):
+            return list(self.clone_and_detach_tensor(x) for x in input_params)
+        elif isinstance(input_params, dict):
+            return {k: self.clone_and_detach_tensor(v) for k, v in input_params.items()}
         else:
-            self.dump_mode_backward_acl_dump(name, module, module_input_output)
-
-    def forward_acl_dump(self, name, module, module_input_output):
-        if not KernelDumpDataProcessor.forward_init_status:
-            KernelDumpDataProcessor.forward_init_status = True
-            torch_npu.npu.synchronize()
-            torch_npu.npu.init_dump()
-            torch_npu.npu.set_dump(self.config.acl_config)
-            torch_npu.npu.synchronize()
-            if self.op_need_trigger(name):
-                module.forward(*module_input_output.args, **module_input_output.kwargs).cpu()
-            else:
-                module.forward(*module_input_output.args, **module_input_output.kwargs)
-            torch_npu.npu.synchronize()
-            torch_npu.npu.finalize_dump()
-            torch_npu.npu.synchronize()
-        KernelDumpDataProcessor.forward_init_status = False
-        logger.info("Dump %s op file." % name)
-
-    def acl_backward_dump_status(self, output, grad, module_name):
-        if isinstance(output, torch.Tensor):
-            output.backward(grad, retain_graph=True)
-            return True
+            return input_params
 
-        for api_name in KernelDumpDataProcessor.multi_output_apis:
-            if api_name in module_name:
-                output[0].backward(grad, retain_graph=True)
-                return True
-        return False
+    def analyze_single_element(self, element, suffix_stack):
+        if isinstance(element, torch.Tensor):
+            if not self.is_found_output_tensor:
+                if element.requires_grad:
+                    self.forward_output_tensor = element
+                    self.is_found_output_tensor = True
+                return {}
+            if not self.is_found_grad_input_tensor:
+                self.grad_input_tensor = element.clone()
+                self.is_found_grad_input_tensor = True
+                return {}
 
-    def dump_mode_backward_acl_dump(self, name, module, module_input_output):
-        grad_path = self.config.backward_input.get(name)
-        if not KernelDumpDataProcessor.forward_init_status:
-            KernelDumpDataProcessor.forward_init_status = True
-            output = module.forward(*module_input_output.args, **module_input_output.kwargs)
-            pt = load_pt(grad_path)
-            grad = pt.to("npu").requires_grad_()
-            torch_npu.npu.init_dump()
-            torch_npu.npu.set_dump(self.config.acl_config)
-            torch_npu.npu.synchronize()
-            if not self.acl_backward_dump_status(output, grad, name):
-                logger.warning("The output of {} is not of tensor type and cannot be automatically derived. "
-                               "you can manually construct a single API backward case for ACL dump.".format(
-                               name))
-            torch_npu.npu.synchronize()
-            torch_npu.npu.finalize_dump()
-        KernelDumpDataProcessor.forward_init_status = False
-        logger.info("Dump %s op file." % name)
-
-    def op_need_trigger(self, module_name):
-        return 'Tensor.__getitem__.' in module_name
+    def reset_status(self):
+        self.enable_kernel_dump = True
+        self.is_found_output_tensor = False
+        self.is_found_grad_input_tensor = False
+        self.forward_args = None
+        self.forward_kwargs = None
+        self.forward_output_tensor = None
+        self.grad_input_tensor = None
```
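The rewritten `KernelDumpDataProcessor` no longer loads a saved gradient from disk: for backward kernel dumps it clones and detaches the forward inputs, re-runs the forward to capture a differentiable output, picks up the first `grad_input` tensor seen by the hook, and replays `backward(grad, retain_graph=True)` between `init_dump` and `finalize_dump`. A CPU-only sketch of that replay pattern (the NPU dump bracketing is omitted; the op and names are illustrative):

```python
import torch

# Stand-ins for what the hooks capture.
args = (torch.randn(3, requires_grad=True),)

# clone_and_detach_tensor equivalent: fresh leaves, detached from the live graph.
replay_args = tuple(a.clone().detach().requires_grad_() if a.requires_grad else a.clone()
                    for a in args)

output = torch.sin(replay_args[0])        # re-run the forward on the copies
grad_input = torch.ones_like(output)      # first tensor seen in grad_input
# On NPU this call would be bracketed by torch_npu.npu.init_dump()/finalize_dump().
output.backward(grad_input, retain_graph=True)
print(replay_args[0].grad)
```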