mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197) hide show
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +14 -19
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +155 -6
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +3 -0
  10. msprobe/core/common/utils.py +28 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +380 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/multiprocessing_compute.py +2 -2
  22. msprobe/core/compare/npy_compare.py +109 -147
  23. msprobe/core/compare/utils.py +189 -69
  24. msprobe/core/data_dump/data_collector.py +51 -21
  25. msprobe/core/data_dump/data_processor/base.py +38 -20
  26. msprobe/core/data_dump/data_processor/factory.py +5 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
  29. msprobe/core/data_dump/json_writer.py +29 -1
  30. msprobe/core/data_dump/scope.py +19 -18
  31. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  32. msprobe/core/overflow_check/checker.py +1 -1
  33. msprobe/core/overflow_check/utils.py +1 -1
  34. msprobe/docs/01.installation.md +96 -17
  35. msprobe/docs/02.config_introduction.md +5 -5
  36. msprobe/docs/05.data_dump_PyTorch.md +91 -61
  37. msprobe/docs/06.data_dump_MindSpore.md +57 -19
  38. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  39. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
  40. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  41. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  42. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  43. msprobe/docs/19.monitor.md +120 -27
  44. msprobe/docs/21.visualization_PyTorch.md +115 -35
  45. msprobe/docs/22.visualization_MindSpore.md +138 -41
  46. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  47. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  48. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  49. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  50. msprobe/docs/27.dump_json_instruction.md +521 -0
  51. msprobe/docs/FAQ.md +26 -2
  52. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  53. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  54. msprobe/docs/img/merge_result.png +0 -0
  55. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  56. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  57. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  58. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  59. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  60. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  61. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  63. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  64. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  65. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  66. msprobe/docs/visualization/GPTModel.png +0 -0
  67. msprobe/docs/visualization/ParallelMLP.png +0 -0
  68. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  69. msprobe/docs/visualization/mapping.png +0 -0
  70. msprobe/docs/visualization/mapping1.png +0 -0
  71. msprobe/docs/visualization/module_name.png +0 -0
  72. msprobe/docs/visualization/module_name1.png +0 -0
  73. msprobe/docs/visualization/no_mapping.png +0 -0
  74. msprobe/docs/visualization/no_mapping1.png +0 -0
  75. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  76. msprobe/docs/visualization/top_layer.png +0 -0
  77. msprobe/mindspore/__init__.py +10 -0
  78. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
  79. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  80. msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
  81. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  82. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  83. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  84. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  85. msprobe/mindspore/code_mapping/bind.py +264 -0
  86. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  87. msprobe/mindspore/code_mapping/graph.py +49 -0
  88. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  89. msprobe/mindspore/code_mapping/main.py +24 -0
  90. msprobe/mindspore/code_mapping/processor.py +34 -0
  91. msprobe/mindspore/common/const.py +3 -1
  92. msprobe/mindspore/common/utils.py +50 -5
  93. msprobe/mindspore/compare/distributed_compare.py +0 -2
  94. msprobe/mindspore/compare/ms_compare.py +105 -63
  95. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  96. msprobe/mindspore/debugger/debugger_config.py +3 -0
  97. msprobe/mindspore/debugger/precision_debugger.py +81 -12
  98. msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
  99. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  100. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  101. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  102. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  103. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  104. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  105. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  106. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  107. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  108. msprobe/mindspore/grad_probe/hook.py +13 -4
  109. msprobe/mindspore/mindtorch/__init__.py +18 -0
  110. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  111. msprobe/mindspore/ms_config.py +5 -1
  112. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  113. msprobe/mindspore/service.py +267 -101
  114. msprobe/msprobe.py +24 -3
  115. msprobe/pytorch/__init__.py +7 -6
  116. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  117. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  123. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  124. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
  125. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  126. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  127. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  128. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  129. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  130. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  131. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  132. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  133. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  134. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  135. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  136. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  140. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  141. msprobe/pytorch/common/parse_json.py +2 -1
  142. msprobe/pytorch/common/utils.py +45 -2
  143. msprobe/pytorch/compare/distributed_compare.py +17 -29
  144. msprobe/pytorch/compare/pt_compare.py +40 -20
  145. msprobe/pytorch/debugger/debugger_config.py +27 -12
  146. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  147. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  148. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  149. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
  150. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  151. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  152. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  153. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  154. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  155. msprobe/pytorch/hook_module/__init__.py +1 -1
  156. msprobe/pytorch/hook_module/hook_module.py +14 -11
  157. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  158. msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
  159. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  160. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  161. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  162. msprobe/pytorch/monitor/anomaly_detect.py +107 -22
  163. msprobe/pytorch/monitor/csv2tb.py +166 -0
  164. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  165. msprobe/pytorch/monitor/features.py +3 -3
  166. msprobe/pytorch/monitor/module_hook.py +483 -277
  167. msprobe/pytorch/monitor/module_metric.py +27 -48
  168. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  169. msprobe/pytorch/monitor/optimizer_collect.py +52 -14
  170. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  171. msprobe/pytorch/monitor/utils.py +77 -6
  172. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  173. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  174. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  175. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  176. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  177. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  178. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  179. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  180. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  181. msprobe/pytorch/service.py +176 -106
  182. msprobe/visualization/builder/graph_builder.py +62 -5
  183. msprobe/visualization/builder/msprobe_adapter.py +24 -2
  184. msprobe/visualization/compare/graph_comparator.py +64 -14
  185. msprobe/visualization/compare/mode_adapter.py +1 -15
  186. msprobe/visualization/graph/base_node.py +12 -17
  187. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  188. msprobe/visualization/graph/graph.py +9 -0
  189. msprobe/visualization/graph_service.py +97 -23
  190. msprobe/visualization/utils.py +14 -29
  191. msprobe/pytorch/functional/module_dump.py +0 -84
  192. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  193. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
  194. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
  195. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  196. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  197. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -56,7 +56,7 @@ class DataProcessorFactory:
56
56
  FreeBenchmarkDataProcessor as PytorchFreeBenchmarkDataProcessor,
57
57
  KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
58
58
  )
59
- from msprobe.pytorch.module_processer import ModuleProcesser
59
+ from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
60
60
  cls.register_processor(Const.PT_FRAMEWORK, Const.STATISTICS, PytorchStatisticsDataProcessor)
61
61
  cls.register_processor(Const.PT_FRAMEWORK, Const.TENSOR, PytorchTensorDataProcessor)
62
62
  cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
@@ -67,10 +67,12 @@ class DataProcessorFactory:
67
67
  from msprobe.core.data_dump.data_processor.mindspore_processor import (
68
68
  StatisticsDataProcessor as MindsporeStatisticsDataProcessor,
69
69
  TensorDataProcessor as MindsporeTensorDataProcessor,
70
- OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor
70
+ OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor,
71
+ KernelDumpDataProcessor as MindsporeKernelDumpDataProcessor
71
72
  )
72
73
  from msprobe.mindspore.cell_processor import CellProcessor
73
74
  cls.register_processor(Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor)
74
75
  cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
75
76
  cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
77
+ cls.register_processor(Const.MS_FRAMEWORK, Const.KERNEL_DUMP, MindsporeKernelDumpDataProcessor)
76
78
  cls.register_module_processor(Const.MS_FRAMEWORK, CellProcessor)
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Huawei Technologies Co., Ltd
1
+ # Copyright 2024-2025 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
16
16
  import zlib
17
17
 
18
18
  import mindspore as ms
19
- from mindspore import mint, ops
19
+ from mindspore import mint, ops, hal
20
20
  from mindspore._c_expression.typing import Number
21
21
  import numpy as np
22
22
 
@@ -28,6 +28,12 @@ from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_
28
28
  from msprobe.mindspore.common.log import logger
29
29
  from msprobe.mindspore.dump.hook_cell.api_registry import api_register
30
30
 
31
+ has_adump = True
32
+ try:
33
+ from msprobe.lib import _msprobe_c
34
+ except ImportError:
35
+ has_adump = False
36
+
31
37
 
32
38
  class MindsporeDataProcessor(BaseDataProcessor):
33
39
  mindspore_special_type = tuple([ms.Tensor, Number])
@@ -37,11 +43,12 @@ class MindsporeDataProcessor(BaseDataProcessor):
37
43
  self.mindspore_object_key = {
38
44
  "dtype": self.analyze_dtype_in_kwargs
39
45
  }
46
+ self._async_dump_cache = {}
40
47
 
41
48
  @staticmethod
42
49
  def get_md5_for_tensor(x):
43
50
  x = convert_bf16_to_fp32(x)
44
- tensor_bytes = x.contiguous().asnumpy().tobytes()
51
+ tensor_bytes = x.asnumpy().tobytes()
45
52
  crc32_hash = zlib.crc32(tensor_bytes)
46
53
  return f"{crc32_hash:08x}"
47
54
 
@@ -49,22 +56,17 @@ class MindsporeDataProcessor(BaseDataProcessor):
49
56
  def analyze_dtype_in_kwargs(element):
50
57
  return {"type": "mindspore.dtype", "value": str(element)}
51
58
 
52
- @classmethod
53
- def get_special_types(cls):
54
- return super().get_special_types() + cls.mindspore_special_type
55
-
56
- def get_stat_info(self, data):
59
+ @staticmethod
60
+ def get_stat_info_sync(data):
57
61
  tensor_stat = TensorStatInfo()
58
- if data.numel() == 0:
59
- return tensor_stat
60
- elif data.dtype == ms.bool_:
61
- data_np = data.contiguous().asnumpy()
62
+ if data.dtype == ms.bool_:
63
+ data_np = data.asnumpy()
62
64
  tensor_stat.max = np.max(data_np).item()
63
65
  tensor_stat.min = np.min(data_np).item()
64
66
  elif not data.shape:
65
67
  tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
66
68
  elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
67
- data_abs = np.abs(data.contiguous().asnumpy())
69
+ data_abs = np.abs(data.asnumpy())
68
70
  tensor_stat.max = np.max(data_abs).item()
69
71
  tensor_stat.min = np.min(data_abs).item()
70
72
  tensor_stat.mean = np.mean(data_abs).item()
@@ -87,6 +89,47 @@ class MindsporeDataProcessor(BaseDataProcessor):
87
89
  api_register.norm_inner_op_set_hook_func()
88
90
  return tensor_stat
89
91
 
92
+ @staticmethod
93
+ def get_stat_info_async(data):
94
+ tensor_stat = TensorStatInfo()
95
+ stack_method = api_register.functional_ori_attr.get("stack", ms.ops.stack)
96
+ if data.dtype == ms.complex64 or data.dtype == ms.complex128:
97
+ logger.warning("Async dump do not support complex data!")
98
+ return tensor_stat
99
+ elif data.dtype == ms.bool_:
100
+ tensor_stat.stack_tensor_stat = (["Max", "Min"], stack_method([data.any(), data.all()]))
101
+ elif not data.shape:
102
+ tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method([data, data, data, data]))
103
+ else:
104
+ if not ops.is_floating_point(data) or data.dtype == ms.float64:
105
+ data = data.to(ms.float32)
106
+ api_register.norm_inner_op_set_ori_func()
107
+ get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
108
+ get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min)
109
+ get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean)
110
+ if hasattr(mint, "norm"):
111
+ get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
112
+ else:
113
+ get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
114
+ tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method(
115
+ [get_max_value(data), get_min_value(data), get_mean_value(data), get_norm_value(data)]))
116
+ api_register.norm_inner_op_set_hook_func()
117
+ return tensor_stat
118
+
119
+ @classmethod
120
+ def get_special_types(cls):
121
+ return super().get_special_types() + cls.mindspore_special_type
122
+
123
+ def get_stat_info(self, data):
124
+ tensor_stat = TensorStatInfo()
125
+ if data.numel() == 0:
126
+ return tensor_stat
127
+ else:
128
+ if self.config.async_dump:
129
+ return MindsporeDataProcessor.get_stat_info_async(data)
130
+ else:
131
+ return MindsporeDataProcessor.get_stat_info_sync(data)
132
+
90
133
  def analyze_single_element(self, element, suffix_stack):
91
134
  if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
92
135
  return self.mindspore_object_key[suffix_stack[-1]](element)
@@ -107,13 +150,17 @@ class MindsporeDataProcessor(BaseDataProcessor):
107
150
  tensor_json = {
108
151
  'type': 'mindspore.Tensor',
109
152
  'dtype': str(tensor.dtype),
110
- 'shape': tensor.shape,
111
- 'Max': self.transfer_type(tensor_stat.max),
112
- 'Min': self.transfer_type(tensor_stat.min),
113
- 'Mean': self.transfer_type(tensor_stat.mean),
114
- 'Norm': self.transfer_type(tensor_stat.norm),
153
+ 'shape': tensor.shape
115
154
  }
116
- if self.config.summary_mode == Const.MD5:
155
+
156
+ if tensor_stat.stack_tensor_stat is None:
157
+ tensor_json.update({'Max': self.transfer_type(tensor_stat.max)})
158
+ tensor_json.update({'Min': self.transfer_type(tensor_stat.min)})
159
+ tensor_json.update({'Mean': self.transfer_type(tensor_stat.mean)})
160
+ tensor_json.update({'Norm': self.transfer_type(tensor_stat.norm)})
161
+ else:
162
+ tensor_json.update({'tensor_stat': tensor_stat.stack_tensor_stat})
163
+ if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
117
164
  tensor_md5 = self.get_md5_for_tensor(tensor)
118
165
  tensor_json.update({Const.MD5: tensor_md5})
119
166
  return tensor_json
@@ -124,11 +171,19 @@ class StatisticsDataProcessor(MindsporeDataProcessor):
124
171
 
125
172
 
126
173
  class TensorDataProcessor(MindsporeDataProcessor):
174
+ def dump_async_data(self):
175
+ for file_path, tensor in self._async_dump_cache.items():
176
+ save_tensor_as_npy(tensor, file_path)
177
+ self._async_dump_cache.clear()
178
+
127
179
  def _analyze_tensor(self, tensor, suffix):
128
180
  dump_data_name, file_path = self.get_save_file_path(suffix)
129
181
  single_arg = super()._analyze_tensor(tensor, suffix)
130
182
  single_arg.update({"data_name": dump_data_name})
131
- save_tensor_as_npy(tensor, file_path)
183
+ if self.config.async_dump:
184
+ self._async_dump_cache[file_path] = tensor.copy()
185
+ else:
186
+ save_tensor_as_npy(tensor, file_path)
132
187
  return single_arg
133
188
 
134
189
 
@@ -138,6 +193,7 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
138
193
  def __init__(self, config, data_writer):
139
194
  super().__init__(config, data_writer)
140
195
  self.has_overflow = False
196
+ self.cached_api_info = {}
141
197
  self.cached_tensors_and_file_paths = {}
142
198
  self.real_overflow_nums = 0
143
199
  self.overflow_nums = config.overflow_nums
@@ -150,6 +206,20 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
150
206
  return True
151
207
  return False
152
208
 
209
+ def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
210
+ self.has_overflow = False
211
+ self.cached_api_info = super().analyze_forward_input(name, module, module_input_output)
212
+ return None
213
+
214
+ def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
215
+ api_info_struct = super().analyze_forward_output(name, module, module_input_output)
216
+ if name in self.cached_api_info and name in api_info_struct:
217
+ self.cached_api_info[name].update(api_info_struct[name])
218
+ elif name in api_info_struct:
219
+ self.cached_api_info = api_info_struct
220
+ self.maybe_save_overflow_data()
221
+ return self.cached_api_info if self.has_overflow else None
222
+
153
223
  def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
154
224
  self.has_overflow = False
155
225
  api_info_struct = super().analyze_forward(name, module, module_input_output)
@@ -161,6 +231,12 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
161
231
  api_info_struct = super().analyze_backward(name, module, module_input_output)
162
232
  self.maybe_save_overflow_data()
163
233
  return api_info_struct if self.has_overflow else None
234
+
235
+ def analyze_params(self, name, param_name, grad):
236
+ self.has_overflow = False
237
+ api_info_struct = super().analyze_params(name, param_name, grad)
238
+ self.maybe_save_overflow_data()
239
+ return api_info_struct if self.has_overflow else None
164
240
 
165
241
  def maybe_save_overflow_data(self):
166
242
  if self.has_overflow:
@@ -190,3 +266,61 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
190
266
  self._analyze_maybe_overflow_tensor(single_arg)
191
267
  single_arg.update({"data_name": dump_data_name})
192
268
  return single_arg
269
+
270
+
271
+ class KernelDumpDataProcessor(MindsporeDataProcessor):
272
+ def __init__(self, config, data_writer):
273
+ super().__init__(config, data_writer)
274
+ self.enable_kernel_dump = True
275
+
276
+ @staticmethod
277
+ def start_kernel_dump(config_path):
278
+ hal.synchronize()
279
+ _msprobe_c.init_dump()
280
+ _msprobe_c.set_dump(config_path)
281
+ hal.synchronize()
282
+
283
+ @staticmethod
284
+ def stop_kernel_dump():
285
+ hal.synchronize()
286
+ _msprobe_c.finalize_dump()
287
+ hal.synchronize()
288
+
289
+ @staticmethod
290
+ def _print_unsupported_log(api_name):
291
+ logger.warning(f"The kernel dump does not support the {api_name} API.")
292
+
293
+ def analyze_forward_input(self, name, module, module_input_output):
294
+ if not self.enable_kernel_dump:
295
+ return
296
+ if not has_adump:
297
+ logger.warning("The current msprobe package does not compile adump, and kernel dump cannot be used.")
298
+ self.enable_kernel_dump = False
299
+ return
300
+ self.start_kernel_dump(self.config.kernel_config_path)
301
+
302
+ def analyze_forward_output(self, name, module, module_input_output):
303
+ if not self.enable_kernel_dump:
304
+ return
305
+ self.enable_kernel_dump = False
306
+ self.stop_kernel_dump()
307
+ logger.info(f"The kernel data of {name} is dumped successfully.")
308
+
309
+ def analyze_backward_input(self, name, module, module_input_output):
310
+ if not self.enable_kernel_dump:
311
+ return
312
+ if not has_adump:
313
+ logger.warning("The current msprobe package does not compile adump, and kernel dump cannot be used.")
314
+ self.enable_kernel_dump = False
315
+ return
316
+ self.start_kernel_dump(self.config.kernel_config_path)
317
+
318
+ def analyze_backward(self, name, module, module_input_output):
319
+ if not self.enable_kernel_dump:
320
+ return
321
+ self.enable_kernel_dump = False
322
+ self.stop_kernel_dump()
323
+ logger.info(f"The kernel data of {name} is dumped successfully.")
324
+
325
+ def reset_status(self):
326
+ self.enable_kernel_dump = True
@@ -54,6 +54,7 @@ class PytorchDataProcessor(BaseDataProcessor):
54
54
  "device": self.analyze_device_in_kwargs,
55
55
  "dtype": self.analyze_dtype_in_kwargs
56
56
  }
57
+ self._async_dump_cache = {}
57
58
 
58
59
  @staticmethod
59
60
  def get_md5_for_tensor(x):
@@ -82,49 +83,80 @@ class PytorchDataProcessor(BaseDataProcessor):
82
83
  return {"type": "torch.dtype", "value": str(element)}
83
84
 
84
85
  @staticmethod
85
- def get_stat_info(data):
86
+ def get_stat_info_async(data):
86
87
  tensor_stat = TensorStatInfo()
87
- if data.is_meta:
88
- return tensor_stat
89
- data_clone = data.detach()
90
- if data_clone.numel() == 0:
88
+ if torch.is_complex(data):
89
+ logger.warning("Async dump do not support complex data!")
91
90
  return tensor_stat
92
- elif data_clone.dtype == torch.bool:
93
- tensor_stat.max = torch._C._VariableFunctionsClass.any(data_clone).item()
94
- tensor_stat.min = torch._C._VariableFunctionsClass.all(data_clone).item()
95
- elif not data_clone.shape:
96
- tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.item()
97
- elif torch.is_complex(data_clone):
98
- data_np = data_clone.cpu().numpy()
91
+ elif data.dtype == torch.bool:
92
+ tensor_stat.stack_tensor_stat = (["Max", "Min"], torch.stack(
93
+ [torch.any(data), torch.all(data)]))
94
+ elif not data.shape:
95
+ tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([data, data, data, data]))
96
+ else:
97
+ if not data.is_floating_point() or data.dtype == torch.float64:
98
+ data = data.float()
99
+ tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([
100
+ torch.max(data),
101
+ torch.min(data),
102
+ torch.mean(data),
103
+ torch.norm(data)
104
+ ]))
105
+ return tensor_stat
106
+
107
+ @staticmethod
108
+ def get_stat_info_sync(data):
109
+ tensor_stat = TensorStatInfo()
110
+ if torch.is_complex(data):
111
+ data_np = data.cpu().numpy()
99
112
  data_abs = np.abs(data_np)
100
113
  tensor_stat.max = np.max(data_abs).item()
101
114
  tensor_stat.min = np.min(data_abs).item()
102
115
  tensor_stat.mean = np.mean(data_abs).item()
116
+ elif data.dtype == torch.bool:
117
+ tensor_stat.max = torch.any(data).item()
118
+ tensor_stat.min = torch.all(data).item()
119
+ elif not data.shape:
120
+ tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
103
121
  else:
104
- if not data_clone.is_floating_point() or data_clone.dtype == torch.float64:
105
- data_clone = data_clone.float()
106
- tensor_stat.max = torch._C._VariableFunctionsClass.max(data_clone).item()
107
- tensor_stat.min = torch._C._VariableFunctionsClass.min(data_clone).item()
108
- tensor_stat.mean = torch._C._VariableFunctionsClass.mean(data_clone).item()
109
- tensor_stat.norm = torch._C._VariableFunctionsClass.norm(data_clone).item()
122
+ if not data.is_floating_point() or data.dtype == torch.float64:
123
+ data = data.float()
124
+ tensor_stat.max = torch.max(data).item()
125
+ tensor_stat.min = torch.min(data).item()
126
+ tensor_stat.mean = torch.mean(data).item()
127
+ tensor_stat.norm = torch.norm(data).item()
110
128
  return tensor_stat
111
129
 
130
+ @staticmethod
131
+ def get_stat_info(data, async_dump=False):
132
+ tensor_stat = TensorStatInfo()
133
+ if data.is_meta:
134
+ return tensor_stat
135
+ data_clone = data.detach()
136
+ if data_clone.numel() == 0:
137
+ return tensor_stat
138
+ else:
139
+ if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump:
140
+ return PytorchDataProcessor.get_stat_info_sync(data_clone)
141
+ else:
142
+ return PytorchDataProcessor.get_stat_info_async(data_clone)
143
+
112
144
  @staticmethod
113
145
  def handle_tensor_extremum_nan_inf(tensor, operator):
114
146
  data_clone = tensor.detach()
115
- data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
116
- if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
147
+ data_nan = torch.isnan(data_clone)
148
+ if int(torch.sum(data_nan)) == data_clone.numel():
117
149
  return float('nan')
118
150
 
119
- finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
120
- if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
121
- finite_values = getattr(torch._C._TensorBase, "__getitem__")(data_clone, finite_mask)
122
- return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
123
- torch._C._VariableFunctionsClass.min(finite_values).item()
151
+ finite_mask = torch.isfinite(data_clone)
152
+ if int(torch.sum(finite_mask)) > 0:
153
+ finite_values = data_clone[finite_mask]
154
+ return torch.max(finite_values).item() if operator == 'max' else \
155
+ torch.min(finite_values).item()
124
156
  else:
125
- data_no_nan = getattr(torch._C._TensorBase, "__getitem__")(data_clone, ~data_nan)
126
- return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
127
- torch._C._VariableFunctionsClass.min(data_no_nan).item()
157
+ data_no_nan = data_clone[~data_nan]
158
+ return torch.max(data_no_nan).item() if operator == 'max' else \
159
+ torch.min(data_no_nan).item()
128
160
 
129
161
  @staticmethod
130
162
  def process_group_hash(arg):
@@ -132,6 +164,10 @@ class PytorchDataProcessor(BaseDataProcessor):
132
164
  group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest()
133
165
  return group_ranks_hash
134
166
 
167
+ @staticmethod
168
+ def is_distributed_op(module):
169
+ return getattr(module, "op_is_distributed", False)
170
+
135
171
  @staticmethod
136
172
  def _analyze_torch_size(arg):
137
173
  return {"type": "torch.Size", "value": list(arg)}
@@ -177,26 +213,35 @@ class PytorchDataProcessor(BaseDataProcessor):
177
213
  return self._analyze_builtin(element)
178
214
  return {}
179
215
 
216
def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
    """Collect forward-output info for *name*, special-casing distributed ops.

    For modules flagged as distributed ops, the args/kwargs are folded into
    the output slot before delegating to the base implementation.
    """
    if self.is_distributed_op(module):
        # NOTE(review): presumably the observable result of a distributed op
        # lives in its (mutated) args/kwargs rather than its return value,
        # so they are mirrored into the output — confirm against the hook layer.
        module_input_output.update_output_with_args_and_kwargs()
    return super().analyze_forward_output(name, module, module_input_output)
220
+
180
221
def _analyze_tensor(self, tensor, suffix):
    """Build the JSON summary dict for *tensor*.

    Two layouts are produced depending on how statistics were computed:
    - synchronous stats (``stack_tensor_stat is None``): scalar Max/Min/Mean/
      Norm fields, plus ``*_except_inf_nan`` fallbacks when the extremum is
      inf/nan;
    - async stats: the packed stats are stored under ``"tensor_stat"`` and
      resolved later (see DataWriter.fill_stack_tensor_data).
    An MD5 digest is appended only in MD5 summary mode and only for
    synchronous dumps.
    """
    # async_dump controls whether stats are gathered eagerly or packed.
    tensor_stat = self.get_stat_info(tensor, self.config.async_dump)
    tensor_json = {}
    tensor_json.update({'type': 'torch.Tensor'})
    tensor_json.update({'dtype': str(tensor.dtype)})
    tensor_json.update({"shape": tensor.shape})
    if tensor_stat.stack_tensor_stat is None:
        tensor_json.update({"Max": tensor_stat.max})
        tensor_json.update({"Min": tensor_stat.min})
        tensor_json.update({"Mean": tensor_stat.mean})
        tensor_json.update({"Norm": tensor_stat.norm})
        tensor_json.update({"requires_grad": tensor.requires_grad})
        # When the recorded extremum is inf/nan, also record the extremum
        # over the remaining (finite, or at least non-nan) values.
        if tensor_stat.max is not None:
            if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
                tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
        if tensor_stat.min is not None:
            if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
                tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")

    else:
        tensor_json.update({"requires_grad": tensor.requires_grad})
        tensor_json.update({"tensor_stat": tensor_stat.stack_tensor_stat})

    # MD5 hashing requires reading the tensor now, so it is skipped in
    # async mode.
    if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
        tensor_md5 = self.get_md5_for_tensor(tensor)
        tensor_json.update({Const.MD5: tensor_md5})
    return tensor_json
@@ -207,12 +252,20 @@ class StatisticsDataProcessor(PytorchDataProcessor):
207
252
 
208
253
 
209
254
  class TensorDataProcessor(PytorchDataProcessor):
255
def dump_async_data(self):
    """Flush all tensors deferred by async dumping to disk, then clear the cache."""
    for file_path, tensor in self._async_dump_cache.items():
        # contiguous() here because the cached clone was not made contiguous
        # at capture time (see _analyze_tensor).
        save_pt(tensor.contiguous(), file_path)
    self._async_dump_cache.clear()
+
210
260
def _analyze_tensor(self, tensor, suffix):
    """Summarize *tensor* (via the base class) and persist its data.

    Sync mode saves the tensor immediately with save_pt; async mode only
    caches a detached clone under its target file path for a later
    dump_async_data() flush.
    """
    dump_data_name, file_path = self.get_save_file_path(suffix)
    single_arg = super()._analyze_tensor(tensor, suffix)
    single_arg.update({"data_name": dump_data_name})
    if self.config.async_dump:
        # Deferred write: contiguity is applied at flush time.
        self._async_dump_cache[file_path] = tensor.clone().detach()
    else:
        saved_tensor = tensor.clone().contiguous().detach()
        save_pt(saved_tensor, file_path)
    return single_arg
217
270
 
218
271
 
@@ -223,7 +276,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
223
276
  super().__init__(config, data_writer)
224
277
  self.has_overflow = False
225
278
  self.support_inf_nan = None
226
- self.cached_inplace_api_info = {}
279
+ self.cached_api_info = {}
227
280
  self.cached_tensors_and_file_paths = {}
228
281
  self.bits_for_overflow = 8
229
282
  self.real_overflow_nums = 0
@@ -237,21 +290,21 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
237
290
  return True
238
291
  return False
239
292
 
240
def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
    """Capture forward-input info but defer emission until output time.

    Resets the per-call overflow flag, probes inf/nan support once, and
    stashes the base-class result in self.cached_api_info so that
    analyze_forward_output can merge input and output info. Always returns
    None — nothing is reported from the input side alone.
    """
    self.has_overflow = False
    self._is_support_inf_nan()
    self.cached_api_info = super().analyze_forward_input(name, module, module_input_output)
    return None
245
298
 
246
def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
    """Merge output info into the cached input info and report on overflow.

    Returns the merged api-info dict only when an overflow was detected for
    this call; otherwise None.
    """
    self._is_support_inf_nan()
    api_info_struct = super().analyze_forward_output(name, module, module_input_output)
    # NOTE(review): assumes the base class always returns a dict here (never
    # None), otherwise the membership tests below would raise — confirm.
    if name in self.cached_api_info and name in api_info_struct:
        self.cached_api_info[name].update(api_info_struct[name])
    elif name in api_info_struct:
        # No cached input entry for this name: fall back to output-only info.
        self.cached_api_info = api_info_struct
    self.handle_overflow()
    return self.cached_api_info if self.has_overflow else None
255
308
 
256
309
  def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
257
310
  self.has_overflow = False
@@ -266,6 +319,13 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
266
319
  api_info_struct = super().analyze_backward(name, module, module_input_output)
267
320
  self.handle_overflow()
268
321
  return api_info_struct if self.has_overflow else None
322
+
323
def analyze_params(self, name, param_name, grad):
    """Analyze a parameter gradient, reporting it only when it overflows.

    Mirrors analyze_forward/analyze_backward: reset the overflow flag, probe
    inf/nan support, delegate to the base class, then return the api-info
    struct only if an overflow was flagged.
    """
    self.has_overflow = False
    self._is_support_inf_nan()
    api_info_struct = super().analyze_params(name, param_name, grad)
    self.handle_overflow()
    return api_info_struct if self.has_overflow else None
269
329
 
270
330
  def handle_overflow(self):
271
331
  if not self.support_inf_nan:
@@ -340,10 +400,10 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor):
340
400
  )
341
401
  return
342
402
 
343
def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
    """Forward the pre-forward hook data to the free-benchmark checker.

    Pure delegation: passes this processor itself along with the call's
    args/kwargs; returns nothing.
    """
    self.checker.pre_forward(name, module, self, module_input_output.args, module_input_output.kwargs)
345
405
 
346
- def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
406
+ def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
347
407
  new_output, unequal_rows = self.checker.forward(
348
408
  name,
349
409
  module,
@@ -388,7 +448,7 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
388
448
def _print_unsupported_log(api_name):
    """Warn that *api_name* cannot be covered by kernel dump."""
    message = f"The kernel dump does not support the {api_name} API."
    logger.warning(message)
390
450
 
391
- def analyze_pre_forward(self, name, module, module_input_output):
451
+ def analyze_forward_input(self, name, module, module_input_output):
392
452
  if not self.enable_kernel_dump:
393
453
  return
394
454
  if is_gpu:
@@ -413,7 +473,7 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
413
473
  return
414
474
  self.start_kernel_dump(self.config.kernel_config_path)
415
475
 
416
- def analyze_forward(self, name, module, module_input_output):
476
+ def analyze_forward_output(self, name, module, module_input_output):
417
477
  if not self.enable_kernel_dump:
418
478
  return
419
479
  if self.config.is_backward_kernel_dump:
@@ -15,10 +15,12 @@
15
15
 
16
16
  import csv
17
17
  import os
18
+ import numpy as np
18
19
 
19
20
  from msprobe.core.common.const import Const, FileCheckConst
20
- from msprobe.core.common.file_utils import change_mode, FileOpen, save_json
21
+ from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json
21
22
  from msprobe.core.common.log import logger
23
+ from msprobe.core.common.exceptions import MsprobeException
22
24
 
23
25
 
24
26
  class DataWriter:
@@ -115,3 +117,29 @@ class DataWriter:
115
117
  self.write_stack_info_json(self.stack_file_path)
116
118
  if self.cache_construct:
117
119
  self.write_construct_info_json(self.construct_file_path)
120
+
121
def fill_stack_tensor_data(self):
    """Resolve all packed "tensor_stat" entries in the cached dump data
    into scalar statistics (see process_stat_data_recursive)."""
    self.process_stat_data_recursive(self.cache_data)
123
+
124
def process_stat_data_recursive(self, data, depth=0):
    """Recursively walk *data* and expand packed "tensor_stat" entries.

    A "tensor_stat" entry is a two-element pair [stat_names, stat_values]
    produced by async dumping; each name/value pair is promoted to a
    top-level key of the enclosing dict and the packed entry is removed.
    Values living on a non-CPU device are moved to CPU first.

    :param data: dict/list/tuple tree from the dump cache; other types are
        left untouched.
    :param depth: current recursion depth, guarded by Const.MAX_DEPTH.
    :raises MsprobeException: RECURSION_LIMIT_ERROR when the tree is deeper
        than Const.MAX_DEPTH.
    """
    if depth > Const.MAX_DEPTH:
        logger.error(f"The maximum depth of recursive process stat data, {Const.MAX_DEPTH} is reached.")
        raise MsprobeException(MsprobeException.RECURSION_LIMIT_ERROR)
    if isinstance(data, dict):
        if "tensor_stat" in data.keys():
            tensor_stat = data["tensor_stat"]
            # Sanity check: expect exactly [index_list, value_list] of equal length.
            if len(tensor_stat) != Const.TENSOR_STAT_LEN or len(tensor_stat[0]) != len(tensor_stat[1]):
                logger.warning("Some bad data in async dump")
            else:
                tensor_stat_index, tensor_stat_data = tensor_stat[0], tensor_stat[1]
                if hasattr(tensor_stat_data, "device") and tensor_stat_data.device != Const.CPU_LOWERCASE:
                    tensor_stat_data = tensor_stat_data.cpu()
                for index, stat in zip(tensor_stat_index, tensor_stat_data):
                    # Bug fix: the original passed a SET literal
                    # ``{index, stat.item()}`` to dict.update, which raises
                    # TypeError (set elements are not key/value pairs). A dict
                    # literal maps each stat name to its scalar value.
                    data.update({index: stat.item()})
                del data["tensor_stat"]
        else:
            for key in data.keys():
                self.process_stat_data_recursive(data[key], depth + 1)
    elif isinstance(data, (list, tuple)):
        for i in data:
            self.process_stat_data_recursive(i, depth + 1)