mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213) hide show
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -17,13 +17,14 @@ import zlib
17
17
 
18
18
  import mindspore as ms
19
19
  from mindspore import mint, ops, hal
20
+ from mindspore.mint import distributed
20
21
  from mindspore._c_expression.typing import Number
21
22
  import numpy as np
22
23
 
23
24
  from msprobe.core.common.const import Const
24
25
  from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo,
25
26
  ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs)
26
- from msprobe.core.common.file_utils import path_len_exceeds_limit, save_npy
27
+ from msprobe.core.common.file_utils import path_len_exceeds_limit
27
28
  from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy
28
29
  from msprobe.mindspore.common.log import logger
29
30
  from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
@@ -36,7 +37,7 @@ except ImportError:
36
37
 
37
38
 
38
39
  class MindsporeDataProcessor(BaseDataProcessor):
39
- mindspore_special_type = tuple([ms.Tensor, Number])
40
+ mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp])
40
41
 
41
42
  def __init__(self, config, data_writer):
42
43
  super().__init__(config, data_writer)
@@ -65,7 +66,7 @@ class MindsporeDataProcessor(BaseDataProcessor):
65
66
  tensor_stat.max = np.max(data_np).item()
66
67
  tensor_stat.min = np.min(data_np).item()
67
68
  elif not data.shape:
68
- tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
69
+ tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
69
70
  elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
70
71
  data_abs = np.abs(data.asnumpy())
71
72
  tensor_stat.max = np.max(data_abs).item()
@@ -76,38 +77,52 @@ class MindsporeDataProcessor(BaseDataProcessor):
76
77
  if not ops.is_floating_point(data) or data.dtype == ms.float64:
77
78
  data = data.to(ms.float32)
78
79
  get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
79
- tensor_stat.max = mint.max(data).item()
80
- tensor_stat.min = mint.min(data).item()
81
- tensor_stat.mean = mint.mean(data).item()
82
- tensor_stat.norm = get_norm_value(data).item()
80
+ tensor_stat.max = mint.max(data)
81
+ tensor_stat.min = mint.min(data)
82
+ tensor_stat.mean = mint.mean(data)
83
+ tensor_stat.norm = get_norm_value(data)
83
84
  return tensor_stat
84
85
 
85
86
  @staticmethod
86
87
  def get_stat_info_async(data):
87
88
  tensor_stat = TensorStatInfo()
88
- if data.dtype == ms.complex64 or data.dtype == ms.complex128:
89
+ if data.dtype == ms.bool_:
90
+ tensor_stat.max = mint.any(data)
91
+ tensor_stat.min = mint.all(data)
92
+ elif not data.shape:
93
+ tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
94
+ elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
89
95
  logger.warning("Async dump do not support complex data!")
90
96
  return tensor_stat
91
- elif data.dtype == ms.bool_:
92
- tensor_stat.stack_tensor_stat = (["Max", "Min"], ops.stack([data.any(), data.all()]))
93
- elif not data.shape:
94
- tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], ops.stack([data, data, data, data]))
95
97
  else:
96
98
  if not ops.is_floating_point(data) or data.dtype == ms.float64:
97
99
  data = data.to(ms.float32)
98
100
  get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
99
- tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], ops.stack(
100
- [mint.max(data), mint.min(data), mint.mean(data), get_norm_value(data)]))
101
+ tensor_stat.max = mint.max(data)
102
+ tensor_stat.min = mint.min(data)
103
+ tensor_stat.mean = mint.mean(data)
104
+ tensor_stat.norm = get_norm_value(data)
101
105
  return tensor_stat
102
106
 
103
107
  @staticmethod
104
108
  def is_hookable_element(element):
105
109
  return hasattr(element, "register_hook") and callable(element.register_hook)
106
110
 
111
+ @staticmethod
112
+ def process_group_hash(arg):
113
+ group_ranks = distributed.get_process_group_ranks(arg)
114
+ group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8'))
115
+ return f"{group_ranks_hash:08x}"
116
+
107
117
  @classmethod
108
118
  def get_special_types(cls):
109
119
  return super().get_special_types() + cls.mindspore_special_type
110
120
 
121
+ def dump_async_data(self):
122
+ for file_path, tensor in self._async_dump_cache.items():
123
+ save_tensor_as_npy(tensor, file_path)
124
+ self._async_dump_cache.clear()
125
+
111
126
  def get_stat_info(self, data):
112
127
  self.api_register.restore_inner_used_api()
113
128
  tensor_stat = TensorStatInfo()
@@ -125,19 +140,34 @@ class MindsporeDataProcessor(BaseDataProcessor):
125
140
  if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
126
141
  return self.mindspore_object_key[suffix_stack[-1]](element)
127
142
 
128
- converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
129
- if converted_numpy is not element:
130
- return {"type": numpy_type, "value": converted_numpy}
131
- if isinstance(element, Number):
132
- return self.analyze_dtype_in_kwargs(element)
133
- if isinstance(element, ms.Tensor):
134
- return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
135
- if isinstance(element, np.ndarray):
136
- return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
137
- if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
138
- return self._analyze_builtin(element)
143
+ suffix_str = Const.SEP.join(str(s) for s in suffix_stack)
144
+ type_analyzer = [
145
+ (MindsporeDataProcessor.builtin_type, self._analyze_builtin),
146
+ (ms.Tensor, lambda e: self._analyze_tensor(e, suffix_str)),
147
+ (Number, self.analyze_dtype_in_kwargs),
148
+ (MindsporeDataProcessor.np_type[:-1], self._analyze_numpy),
149
+ (np.ndarray, lambda e: self._analyze_ndarray(e, suffix_str)),
150
+ (distributed.P2POp, lambda e: self._analyze_p2pop(e, suffix_str))
151
+ ]
152
+ for type_key, analyze_fn in type_analyzer:
153
+ if isinstance(element, type_key):
154
+ return analyze_fn(element)
139
155
  return {}
140
156
 
157
+ def _analyze_p2pop(self, arg, suffix):
158
+ p2pop_info = {"class_type": "mindspore.mint.distributed.P2POp"}
159
+ try:
160
+ tensor_info = self._analyze_tensor(arg.tensor, suffix)
161
+ p2pop_info.update({"tensor": tensor_info})
162
+ p2pop_info.update({"op": arg.op})
163
+ p2pop_info.update({"peer": arg.peer})
164
+ p2pop_info.update({"tag": arg.tag})
165
+ group_id = self.process_group_hash(arg.group) if arg.group else None
166
+ p2pop_info.update({"group_id": group_id})
167
+ except Exception as e:
168
+ logger.warning(f"Failed to parse the P2POp content with error info: {e}.")
169
+ return p2pop_info
170
+
141
171
  def _analyze_tensor(self, tensor, suffix):
142
172
  tensor_stat = self.get_stat_info(tensor)
143
173
  tensor_json = {
@@ -146,32 +176,26 @@ class MindsporeDataProcessor(BaseDataProcessor):
146
176
  'shape': tensor.shape
147
177
  }
148
178
 
149
- if tensor_stat.stack_tensor_stat is None:
150
- tensor_json.update({'Max': self.transfer_type(tensor_stat.max)})
151
- tensor_json.update({'Min': self.transfer_type(tensor_stat.min)})
152
- tensor_json.update({'Mean': self.transfer_type(tensor_stat.mean)})
153
- tensor_json.update({'Norm': self.transfer_type(tensor_stat.norm)})
154
- else:
155
- tensor_json.update({'tensor_stat': tensor_stat.stack_tensor_stat})
179
+ # 将统计值存入全局 buffer,并返回占位索引
180
+ stat_values = [
181
+ tensor_stat.max,
182
+ tensor_stat.min,
183
+ tensor_stat.mean,
184
+ tensor_stat.norm
185
+ ]
186
+
187
+ placeholder_index = self.data_writer.append_stat_to_buffer(stat_values)
188
+
189
+ tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index})
190
+
156
191
  if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
157
192
  tensor_md5 = self.get_md5_for_tensor(tensor)
158
193
  tensor_json.update({Const.MD5: tensor_md5})
159
194
  return tensor_json
160
195
 
161
-
162
- class StatisticsDataProcessor(MindsporeDataProcessor):
163
- pass
164
-
165
-
166
- class TensorDataProcessor(MindsporeDataProcessor):
167
- def dump_async_data(self):
168
- for file_path, tensor in self._async_dump_cache.items():
169
- save_tensor_as_npy(tensor, file_path)
170
- self._async_dump_cache.clear()
171
-
172
- def _analyze_tensor(self, tensor, suffix):
196
+ def _analyze_and_save_tensor(self, tensor, suffix):
173
197
  dump_data_name, file_path = self.get_save_file_path(suffix)
174
- single_arg = super()._analyze_tensor(tensor, suffix)
198
+ single_arg = MindsporeDataProcessor._analyze_tensor(self, tensor, suffix)
175
199
  single_arg.update({"data_name": dump_data_name})
176
200
  if self.config.async_dump:
177
201
  self._async_dump_cache[file_path] = tensor.copy()
@@ -179,12 +203,27 @@ class TensorDataProcessor(MindsporeDataProcessor):
179
203
  save_tensor_as_npy(tensor, file_path)
180
204
  return single_arg
181
205
 
182
- def _analyze_numpy(self, ndarray, suffix):
183
- dump_data_name, file_path = self.get_save_file_path(suffix)
184
- save_npy(ndarray, file_path)
185
- ndarray_json = super()._analyze_numpy(ndarray, suffix)
186
- ndarray_json.update({"data_name": dump_data_name})
187
- return ndarray_json
206
+
207
+ class StatisticsDataProcessor(MindsporeDataProcessor):
208
+ def _analyze_tensor(self, tensor, suffix):
209
+ if any(item in self.current_api_or_module_name for item in self.config.tensor_list):
210
+ return self._analyze_and_save_tensor(tensor, suffix)
211
+ else:
212
+ return super()._analyze_tensor(tensor, suffix)
213
+
214
+ def _analyze_ndarray(self, ndarray, suffix):
215
+ if any(item in self.current_api_or_module_name for item in self.config.tensor_list):
216
+ return self._analyze_and_save_ndarray(ndarray, suffix)
217
+ else:
218
+ return super()._analyze_ndarray(ndarray, suffix)
219
+
220
+
221
+ class TensorDataProcessor(MindsporeDataProcessor):
222
+ def _analyze_tensor(self, tensor, suffix):
223
+ return self._analyze_and_save_tensor(tensor, suffix)
224
+
225
+ def _analyze_ndarray(self, ndarray, suffix):
226
+ return self._analyze_and_save_ndarray(ndarray, suffix)
188
227
 
189
228
 
190
229
  class OverflowCheckDataProcessor(MindsporeDataProcessor):
@@ -231,7 +270,7 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
231
270
  api_info_struct = super().analyze_backward(name, module, module_input_output)
232
271
  self.maybe_save_overflow_data()
233
272
  return api_info_struct if self.has_overflow else None
234
-
273
+
235
274
  def analyze_params(self, name, param_name, grad):
236
275
  self.has_overflow = False
237
276
  api_info_struct = super().analyze_params(name, param_name, grad)
@@ -249,11 +288,26 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
249
288
  self.cached_tensors_and_file_paths = {}
250
289
 
251
290
  def _analyze_maybe_overflow_tensor(self, tensor_json):
252
- if tensor_json['Max'] is None:
291
+ tensor_stat_index = tensor_json.get(Const.TENSOR_STAT_INDEX)
292
+ if tensor_stat_index is None:
293
+ logger.warning("tensor_stat_index does not exist in tensor_json.")
253
294
  return
254
- if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']):
295
+ max_tensor = self.data_writer.get_buffer_values_max(tensor_stat_index)
296
+ min_tensor = self.data_writer.get_buffer_values_min(tensor_stat_index)
297
+ if max_tensor is None or min_tensor is None:
298
+ return
299
+
300
+ def check_inf_nan(value):
301
+ # Use .item() if it's a tensor-like structure
302
+ if hasattr(value, "item"):
303
+ value = value.item()
304
+ return np.isinf(value) or np.isnan(value)
305
+
306
+ if check_inf_nan(max_tensor):
255
307
  self.has_overflow = True
256
- if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']):
308
+ return
309
+
310
+ if check_inf_nan(min_tensor):
257
311
  self.has_overflow = True
258
312
 
259
313
  def _analyze_tensor(self, tensor, suffix):
@@ -13,7 +13,6 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import hashlib
17
16
  import zlib
18
17
  from dataclasses import asdict
19
18
  from typing import List
@@ -102,19 +101,17 @@ class PytorchDataProcessor(BaseDataProcessor):
102
101
  logger.warning("Async dump do not support complex data!")
103
102
  return tensor_stat
104
103
  elif data.dtype == torch.bool:
105
- tensor_stat.stack_tensor_stat = (["Max", "Min"], torch.stack(
106
- [torch.any(data), torch.all(data)]))
104
+ tensor_stat.max = torch.any(data)
105
+ tensor_stat.min = torch.all(data)
107
106
  elif not data.shape:
108
- tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([data, data, data, data]))
107
+ tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
109
108
  else:
110
- if not data.is_floating_point() or data.dtype == torch.float64:
109
+ if data.dtype == torch.float64 or not data.is_floating_point():
111
110
  data = data.float()
112
- tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([
113
- torch.max(data),
114
- torch.min(data),
115
- torch.mean(data),
116
- torch.norm(data)
117
- ]))
111
+ tensor_stat.max = torch.max(data)
112
+ tensor_stat.min = torch.min(data)
113
+ tensor_stat.mean = torch.mean(data)
114
+ tensor_stat.norm = torch.norm(data)
118
115
  return tensor_stat
119
116
 
120
117
  @staticmethod
@@ -127,17 +124,17 @@ class PytorchDataProcessor(BaseDataProcessor):
127
124
  tensor_stat.min = np.min(data_abs).item()
128
125
  tensor_stat.mean = np.mean(data_abs).item()
129
126
  elif data.dtype == torch.bool:
130
- tensor_stat.max = torch.any(data).item()
131
- tensor_stat.min = torch.all(data).item()
127
+ tensor_stat.max = torch.any(data)
128
+ tensor_stat.min = torch.all(data)
132
129
  elif not data.shape:
133
- tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
130
+ tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
134
131
  else:
135
- if not data.is_floating_point() or data.dtype == torch.float64:
132
+ if data.dtype == torch.float64 or not data.is_floating_point():
136
133
  data = data.float()
137
- tensor_stat.max = torch.max(data).item()
138
- tensor_stat.min = torch.min(data).item()
139
- tensor_stat.mean = torch.mean(data).item()
140
- tensor_stat.norm = torch.norm(data).item()
134
+ tensor_stat.max = torch.max(data)
135
+ tensor_stat.min = torch.min(data)
136
+ tensor_stat.mean = torch.mean(data)
137
+ tensor_stat.norm = torch.norm(data)
141
138
  return tensor_stat
142
139
 
143
140
  @staticmethod
@@ -174,12 +171,8 @@ class PytorchDataProcessor(BaseDataProcessor):
174
171
  @staticmethod
175
172
  def process_group_hash(arg):
176
173
  group_ranks = dist.get_process_group_ranks(arg)
177
- group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest()
178
- return group_ranks_hash
179
-
180
- @staticmethod
181
- def is_distributed_op(module):
182
- return getattr(module, "op_is_distributed", False)
174
+ group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8'))
175
+ return f"{group_ranks_hash:08x}"
183
176
 
184
177
  @staticmethod
185
178
  def is_hookable_element(element):
@@ -233,34 +226,31 @@ class PytorchDataProcessor(BaseDataProcessor):
233
226
  def get_special_types(cls):
234
227
  return super().get_special_types() + cls.pytorch_special_type
235
228
 
229
+ def dump_async_data(self):
230
+ for file_path, tensor in self._async_dump_cache.items():
231
+ save_pt(tensor.contiguous(), file_path)
232
+ self._async_dump_cache.clear()
233
+
236
234
  def analyze_single_element(self, element, suffix_stack):
237
235
  if suffix_stack and suffix_stack[-1] in self.torch_object_key:
238
236
  return self.torch_object_key[suffix_stack[-1]](element)
239
- if isinstance(element, torch.Size):
240
- return self._analyze_torch_size(element)
241
- if isinstance(element, torch.memory_format):
242
- return self._analyze_memory_format(element)
243
- if isinstance(element, dist.ProcessGroup):
244
- return self._analyze_process_group(element)
245
- if isinstance(element, dist.P2POp):
246
- return self._analyze_p2pop(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
247
- if isinstance(element, dist.ReduceOp):
248
- return self._analyze_reduce_op(element)
249
- converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
250
- if converted_numpy is not element:
251
- return {"type": numpy_type, "value": converted_numpy}
252
- if isinstance(element, torch.Tensor):
253
- return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
254
- if isinstance(element, np.ndarray):
255
- return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
256
- if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
257
- return self._analyze_builtin(element)
258
- return {}
259
237
 
260
- def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
261
- if self.is_distributed_op(module):
262
- module_input_output.update_output_with_args_and_kwargs()
263
- return super().analyze_forward_output(name, module, module_input_output)
238
+ suffix_str = Const.SEP.join(str(s) for s in suffix_stack)
239
+ type_analyzer = [
240
+ (PytorchDataProcessor.builtin_type, self._analyze_builtin),
241
+ (torch.Size, self._analyze_torch_size),
242
+ (torch.Tensor, lambda e: self._analyze_tensor(e, suffix_str)),
243
+ (torch.memory_format, self._analyze_memory_format),
244
+ (dist.ProcessGroup, self._analyze_process_group),
245
+ (dist.P2POp, lambda e: self._analyze_p2pop(e, suffix_str)),
246
+ (dist.ReduceOp, self._analyze_reduce_op),
247
+ (PytorchDataProcessor.np_type[:-1], self._analyze_numpy),
248
+ (np.ndarray, lambda e: self._analyze_ndarray(e, suffix_str)),
249
+ ]
250
+ for type_key, analyze_fn in type_analyzer:
251
+ if isinstance(element, type_key):
252
+ return analyze_fn(element)
253
+ return {}
264
254
 
265
255
  def _analyze_p2pop(self, arg, suffix):
266
256
  p2pop_info = {"class_type": "torch.distributed.P2POp"}
@@ -284,42 +274,26 @@ class PytorchDataProcessor(BaseDataProcessor):
284
274
  tensor_json.update({'type': 'torch.Tensor'})
285
275
  tensor_json.update({'dtype': dtype})
286
276
  tensor_json.update({"shape": tensor.shape})
287
- if tensor_stat.stack_tensor_stat is None:
288
- tensor_json.update({"Max": tensor_stat.max})
289
- tensor_json.update({"Min": tensor_stat.min})
290
- tensor_json.update({"Mean": tensor_stat.mean})
291
- tensor_json.update({"Norm": tensor_stat.norm})
292
- tensor_json.update({"requires_grad": tensor.requires_grad})
293
- if tensor_stat.max is not None:
294
- if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
295
- tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
296
- if tensor_stat.min is not None:
297
- if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
298
- tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")
299
277
 
300
- else:
301
- tensor_json.update({"requires_grad": tensor.requires_grad})
302
- tensor_json.update({"tensor_stat": tensor_stat.stack_tensor_stat})
278
+ stat_values = [
279
+ tensor_stat.max,
280
+ tensor_stat.min,
281
+ tensor_stat.mean,
282
+ tensor_stat.norm
283
+ ]
284
+ placeholder_index = self.data_writer.append_stat_to_buffer(stat_values)
285
+
286
+ tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index})
287
+ tensor_json.update({"requires_grad": tensor.requires_grad})
303
288
 
304
289
  if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
305
290
  tensor_md5 = self.get_md5_for_tensor(tensor)
306
291
  tensor_json.update({Const.MD5: tensor_md5})
307
292
  return tensor_json
308
293
 
309
-
310
- class StatisticsDataProcessor(PytorchDataProcessor):
311
- pass
312
-
313
-
314
- class TensorDataProcessor(PytorchDataProcessor):
315
- def dump_async_data(self):
316
- for file_path, tensor in self._async_dump_cache.items():
317
- save_pt(tensor.contiguous(), file_path)
318
- self._async_dump_cache.clear()
319
-
320
- def _analyze_tensor(self, tensor, suffix):
294
+ def _analyze_and_save_tensor(self, tensor, suffix):
321
295
  dump_data_name, file_path = self.get_save_file_path(suffix)
322
- single_arg = super()._analyze_tensor(tensor, suffix)
296
+ single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix)
323
297
  single_arg.update({"data_name": dump_data_name})
324
298
  tensor, _ = self._cast_to_float_if_fp8(tensor)
325
299
  if self.config.async_dump:
@@ -329,14 +303,36 @@ class TensorDataProcessor(PytorchDataProcessor):
329
303
  save_pt(saved_tensor, file_path)
330
304
  return single_arg
331
305
 
332
- def _analyze_numpy(self, ndarray, suffix):
306
+ def _analyze_and_save_ndarray(self, ndarray, suffix):
333
307
  dump_data_name, file_path = self.get_save_file_path(suffix)
334
308
  save_pt(torch.tensor(ndarray), file_path)
335
- ndarray_json = super()._analyze_numpy(ndarray, suffix)
309
+ ndarray_json = PytorchDataProcessor._analyze_ndarray(ndarray, suffix)
336
310
  ndarray_json.update({"data_name": dump_data_name})
337
311
  return ndarray_json
338
312
 
339
313
 
314
+ class StatisticsDataProcessor(PytorchDataProcessor):
315
+ def _analyze_tensor(self, tensor, suffix):
316
+ if any(item in self.current_api_or_module_name for item in self.config.tensor_list):
317
+ return self._analyze_and_save_tensor(tensor, suffix)
318
+ else:
319
+ return super()._analyze_tensor(tensor, suffix)
320
+
321
+ def _analyze_ndarray(self, ndarray, suffix):
322
+ if any(item in self.current_api_or_module_name for item in self.config.tensor_list):
323
+ return self._analyze_and_save_ndarray(ndarray, suffix)
324
+ else:
325
+ return super()._analyze_ndarray(ndarray, suffix)
326
+
327
+
328
+ class TensorDataProcessor(PytorchDataProcessor):
329
+ def _analyze_tensor(self, tensor, suffix):
330
+ return self._analyze_and_save_tensor(tensor, suffix)
331
+
332
+ def _analyze_ndarray(self, ndarray, suffix):
333
+ return self._analyze_and_save_ndarray(ndarray, suffix)
334
+
335
+
340
336
  class OverflowCheckDataProcessor(PytorchDataProcessor):
341
337
  __slots__ = ["cached_tensors_and_file_paths"]
342
338
 
@@ -427,10 +423,22 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
427
423
  raise RuntimeError(f"overflow check failed") from e
428
424
 
429
425
  def _analyze_maybe_overflow_tensor(self, tensor_json):
430
- if tensor_json['Max'] is None or tensor_json['Min'] is None:
426
+ tensor_stat_index = tensor_json.get(Const.TENSOR_STAT_INDEX)
427
+ if tensor_stat_index is None:
428
+ logger.warning("tensor_stat_index does not exist in tensor_json.")
431
429
  return
432
- self.has_overflow = np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']) or \
433
- np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min'])
430
+ max_tensor = self.data_writer.get_buffer_values_max(tensor_stat_index)
431
+ min_tensor = self.data_writer.get_buffer_values_min(tensor_stat_index)
432
+
433
+ if max_tensor is None or min_tensor is None:
434
+ return
435
+
436
+ if torch.isinf(max_tensor) or torch.isnan(max_tensor):
437
+ self.has_overflow = True
438
+ return
439
+
440
+ if torch.isinf(min_tensor) or torch.isnan(min_tensor):
441
+ self.has_overflow = True
434
442
 
435
443
  def _analyze_tensor(self, tensor, suffix):
436
444
  dump_data_name, file_path = self.get_save_file_path(suffix)