mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
@@ -13,7 +13,9 @@
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
15
 
16
+ import os
16
17
  import zlib
18
+ from concurrent.futures import ThreadPoolExecutor
17
19
 
18
20
  import mindspore as ms
19
21
  from mindspore import mint, ops, hal
@@ -53,6 +55,11 @@ class MindsporeDataProcessor(BaseDataProcessor):
53
55
  }
54
56
  self._async_dump_cache = {}
55
57
  self.api_register = get_api_register()
58
+ self._crc_executor = ThreadPoolExecutor(max_workers=os.cpu_count() // 2)
59
+
60
+ @staticmethod
61
+ def compute_crc32_bytes(tensor_bytes):
62
+ return f"{zlib.crc32(tensor_bytes):08x}"
56
63
 
57
64
  @staticmethod
58
65
  def get_md5_for_tensor(x):
@@ -65,52 +72,6 @@ class MindsporeDataProcessor(BaseDataProcessor):
65
72
  def analyze_dtype_in_kwargs(element):
66
73
  return {"type": "mindspore.dtype", "value": str(element)}
67
74
 
68
- @staticmethod
69
- def get_stat_info_sync(data):
70
- tensor_stat = TensorStatInfo()
71
- if data.dtype == ms.bool_:
72
- data_np = data.asnumpy()
73
- tensor_stat.max = np.max(data_np).item()
74
- tensor_stat.min = np.min(data_np).item()
75
- elif not data.shape:
76
- tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
77
- elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
78
- data_abs = np.abs(data.asnumpy())
79
- tensor_stat.max = np.max(data_abs).item()
80
- tensor_stat.min = np.min(data_abs).item()
81
- tensor_stat.mean = np.mean(data_abs).item()
82
- tensor_stat.norm = np.linalg.norm(data_abs).item()
83
- else:
84
- if not ops.is_floating_point(data) or data.dtype == ms.float64:
85
- data = data.to(ms.float32)
86
- get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
87
- tensor_stat.max = mint.max(data)
88
- tensor_stat.min = mint.min(data)
89
- tensor_stat.mean = mint.mean(data)
90
- tensor_stat.norm = get_norm_value(data)
91
- return tensor_stat
92
-
93
- @staticmethod
94
- def get_stat_info_async(data):
95
- tensor_stat = TensorStatInfo()
96
- if data.dtype == ms.bool_:
97
- tensor_stat.max = mint.any(data)
98
- tensor_stat.min = mint.all(data)
99
- elif not data.shape:
100
- tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
101
- elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
102
- logger.warning("Async dump do not support complex data!")
103
- return tensor_stat
104
- else:
105
- if not ops.is_floating_point(data) or data.dtype == ms.float64:
106
- data = data.to(ms.float32)
107
- get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
108
- tensor_stat.max = mint.max(data)
109
- tensor_stat.min = mint.min(data)
110
- tensor_stat.mean = mint.mean(data)
111
- tensor_stat.norm = get_norm_value(data)
112
- return tensor_stat
113
-
114
75
  @staticmethod
115
76
  def is_hookable_element(element):
116
77
  return hasattr(element, "register_hook") and callable(element.register_hook)
@@ -147,14 +108,37 @@ class MindsporeDataProcessor(BaseDataProcessor):
147
108
  self.api_register.restore_inner_used_api()
148
109
  tensor_stat = TensorStatInfo()
149
110
  if data.numel() == 0:
150
- stat_info = tensor_stat
151
- else:
111
+ pass
112
+ elif data.dtype == ms.bool_:
113
+ if self.config.async_dump:
114
+ tensor_stat.max = mint.any(data)
115
+ tensor_stat.min = mint.all(data)
116
+ else:
117
+ data_np = data.asnumpy()
118
+ tensor_stat.max = np.max(data_np).item()
119
+ tensor_stat.min = np.min(data_np).item()
120
+ elif not data.shape:
121
+ tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.copy()
122
+ elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
152
123
  if self.config.async_dump:
153
- stat_info = MindsporeDataProcessor.get_stat_info_async(data)
124
+ logger.warning("Async dump do not support complex data!")
154
125
  else:
155
- stat_info = MindsporeDataProcessor.get_stat_info_sync(data)
126
+ data_abs = np.abs(data.asnumpy())
127
+ tensor_stat.max = np.max(data_abs).item()
128
+ tensor_stat.min = np.min(data_abs).item()
129
+ tensor_stat.mean = np.mean(data_abs).item()
130
+ tensor_stat.norm = np.linalg.norm(data_abs).item()
131
+ else:
132
+ if self.config.precision == Const.DUMP_PRECISION_HIGH or not ops.is_floating_point(
133
+ data) or data.dtype == ms.float64:
134
+ data = data.to(ms.float32)
135
+ get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
136
+ tensor_stat.max = mint.max(data)
137
+ tensor_stat.min = mint.min(data)
138
+ tensor_stat.mean = mint.mean(data)
139
+ tensor_stat.norm = get_norm_value(data)
156
140
  self.api_register.register_inner_used_api()
157
- return stat_info
141
+ return tensor_stat
158
142
 
159
143
  def analyze_single_element(self, element, suffix_stack):
160
144
  if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
@@ -211,8 +195,18 @@ class MindsporeDataProcessor(BaseDataProcessor):
211
195
  tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index})
212
196
 
213
197
  if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
214
- tensor_md5 = self.get_md5_for_tensor(tensor)
215
- tensor_json.update({Const.MD5: tensor_md5})
198
+ tensor = convert_bf16_to_fp32(tensor)
199
+ # 拷贝并搬到 CPU
200
+ tensor_bytes = tensor.asnumpy()
201
+
202
+ future = self._crc_executor.submit(
203
+ MindsporeDataProcessor.compute_crc32_bytes,
204
+ tensor_bytes
205
+ )
206
+
207
+ crc_placeholder = self.data_writer.append_crc32_to_buffer(future)
208
+ tensor_json[Const.MD5_INDEX] = crc_placeholder
209
+
216
210
  return tensor_json
217
211
 
218
212
  def _analyze_and_save_tensor(self, tensor, suffix):
@@ -13,7 +13,11 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import ctypes
17
+ import os
16
18
  import zlib
19
+ from collections.abc import Iterable
20
+ from concurrent.futures import ThreadPoolExecutor
17
21
  from dataclasses import asdict
18
22
  from typing import List
19
23
 
@@ -23,11 +27,10 @@ from torch import distributed as dist
23
27
  from torch.distributed.distributed_c10d import _get_default_group
24
28
 
25
29
  from msprobe.core.common.const import Const
30
+ from msprobe.core.common.decorator import recursion_depth_decorator
26
31
  from msprobe.core.common.exceptions import MsprobeException
27
- from msprobe.core.common.file_utils import path_len_exceeds_limit
28
32
  from msprobe.core.common.log import logger
29
- from msprobe.core.common.utils import convert_tuple
30
- from msprobe.core.common.decorator import recursion_depth_decorator
33
+ from msprobe.core.common.utils import convert_tuple, is_int
31
34
  from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
32
35
  ModuleForwardInputsOutputs, TensorStatInfo
33
36
  from msprobe.pytorch.common.utils import save_pt
@@ -40,6 +43,84 @@ except ImportError:
40
43
  is_gpu = True
41
44
 
42
45
 
46
+ class TensorHandler:
47
+ def __init__(self):
48
+ self.has_dtensor = hasattr(dist, "tensor") and hasattr(dist.tensor, "DTensor")
49
+ self.has_fake_tensor = hasattr(torch, "_subclasses") and hasattr(torch._subclasses, "fake_tensor")
50
+ self.has_async_collective_tensor = hasattr(dist, "_functional_collectives") and \
51
+ hasattr(dist._functional_collectives, "AsyncCollectiveTensor")
52
+
53
+ @staticmethod
54
+ def free_tensor(tensor, tensor_name):
55
+ try:
56
+ tensor.untyped_storage().resize_(0)
57
+ except Exception as e:
58
+ logger.warning(f"Failed to free tensor: {tensor_name}, the detail info: {e}.")
59
+
60
+ def is_dtensor(self, tensor):
61
+ return self.has_dtensor and isinstance(tensor, dist.tensor.DTensor)
62
+
63
+ def is_fake_tensor(self, tensor):
64
+ return self.has_fake_tensor and isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor)
65
+
66
+ def is_async_collective_tensor(self, tensor):
67
+ return self.has_async_collective_tensor and \
68
+ isinstance(tensor, dist._functional_collectives.AsyncCollectiveTensor)
69
+
70
+ def is_empty_data(self, tensor):
71
+ return tensor.is_meta or self.is_fake_tensor(tensor) or self.is_async_collective_tensor(tensor)
72
+
73
+ def convert_common_tensor(self, tensor):
74
+ if self.is_dtensor(tensor):
75
+ return tensor.to_local()
76
+ if self.is_fake_tensor(tensor):
77
+ logger.debug("FakeTensor cannot be converted to torch.Tensor type.")
78
+ return tensor
79
+ return tensor
80
+
81
+ def get_tensor_type(self, tensor):
82
+ if self.is_dtensor(tensor):
83
+ return Const.DTENSOR_TYPE
84
+ if self.is_fake_tensor(tensor):
85
+ return Const.FAKE_TENSOR_TYPE
86
+ if self.is_async_collective_tensor(tensor):
87
+ return Const.AC_TENSOR_TYPE
88
+ return Const.TENSOR_TYPE
89
+
90
+ def get_dtensor_info(self, tensor):
91
+ dtensor_info = {}
92
+ if not self.is_dtensor(tensor):
93
+ return dtensor_info
94
+ if hasattr(tensor, "device_mesh") and tensor.device_mesh:
95
+ dtensor_info.update({"device_mesh": tensor.device_mesh.mesh.tolist()})
96
+
97
+ placements = []
98
+ if hasattr(tensor, "placements") and isinstance(tensor.placements, Iterable):
99
+ for placement in tensor.placements:
100
+ if placement.is_shard() and is_int(placement.dim):
101
+ placements.append({"Shard": {"dim": placement.dim}})
102
+ continue
103
+ if placement.is_replicate():
104
+ placements.append({"Replicate": {}})
105
+ continue
106
+ if placement.is_partial() and isinstance(placement.reduce_op, str):
107
+ placements.append({"Partial": {"reduce_op": placement.reduce_op}})
108
+ dtensor_info.update({"placements": placements})
109
+ return dtensor_info
110
+
111
+ def save_tensor(self, tensor, file_path):
112
+ common_tensor = self.convert_common_tensor(tensor)
113
+ if self.is_empty_data(common_tensor):
114
+ logger.debug(f"Saving fake tensor or meta tensor is not supported, the current tensor is {file_path}.")
115
+ return
116
+ if common_tensor.untyped_storage().data_ptr() == 0:
117
+ logger.debug(f"Saving null-pointer tensor is not supported, the current tensor is {file_path}.")
118
+ return
119
+ saved_tensor = common_tensor.clone().contiguous().detach()
120
+ save_pt(saved_tensor, file_path)
121
+ self.free_tensor(saved_tensor, file_path)
122
+
123
+
43
124
  class PytorchDataProcessor(BaseDataProcessor):
44
125
  pytorch_special_type = (
45
126
  torch.device,
@@ -65,6 +146,8 @@ class PytorchDataProcessor(BaseDataProcessor):
65
146
  "dtype": self.analyze_dtype_in_kwargs
66
147
  }
67
148
  self._async_dump_cache = {}
149
+ self.tensor_handler = TensorHandler()
150
+ self._crc_executor = ThreadPoolExecutor(max_workers=os.cpu_count() // 2)
68
151
 
69
152
  @staticmethod
70
153
  def get_md5_for_tensor(x):
@@ -74,6 +157,64 @@ class PytorchDataProcessor(BaseDataProcessor):
74
157
  crc32_hash = zlib.crc32(tensor_bytes)
75
158
  return f"{crc32_hash:08x}"
76
159
 
160
+ @staticmethod
161
+ def tensor_bytes_view_cpu(t: torch.Tensor):
162
+ """
163
+ 返回 t 在当前 dtype 下的原始字节视图(优先零拷贝)。
164
+ 需保证:t 已在 CPU 且是 contiguous。
165
+ 可能返回 memoryview 或 bytes(兜底拷贝)或者 转为numpy,均可被 zlib.crc32 接受。
166
+ """
167
+
168
+ nbytes = t.numel() * t.element_size()
169
+ byte_offset = t.storage_offset() * t.element_size()
170
+
171
+ if nbytes == 0:
172
+ return memoryview(b"")
173
+
174
+ storage = t.untyped_storage()
175
+
176
+ # ctypes 指针构造 memoryview(零拷贝 FFI)
177
+ try:
178
+ addr = storage.data_ptr() + byte_offset
179
+ buf = (ctypes.c_ubyte * nbytes).from_address(addr)
180
+ mv3 = memoryview(buf)
181
+
182
+ return mv3
183
+ except Exception as e1:
184
+ logger.warning(f"path_A_failed: {e1}.")
185
+
186
+ try:
187
+ data = ctypes.string_at(storage.data_ptr() + byte_offset, nbytes)
188
+
189
+ return data # bytes 也可直接用于 zlib.crc32
190
+ except Exception as e2:
191
+ logger.warning(f"path_B_failed: {e2}.")
192
+
193
+ try:
194
+ if t.dtype == torch.bfloat16:
195
+ t = t.float()
196
+ data = t.numpy()
197
+
198
+ return data
199
+ except Exception as e3:
200
+ logger.warning(f"path_C_failed: {e3}.")
201
+ return memoryview(b"")
202
+
203
+ @staticmethod
204
+ def compute_crc32_from_tensor(t: torch.Tensor) -> str:
205
+ """
206
+ 直接对 Tensor 原始字节做 CRC32。
207
+ :
208
+ - "raw": 保持 bfloat16 原始 16bit 字节(推荐,避免升精/增容)
209
+ """
210
+
211
+ # 取得字节视图(含多级回退),然后做 CRC
212
+ mv = PytorchDataProcessor.tensor_bytes_view_cpu(t)
213
+
214
+ crc = zlib.crc32(mv)
215
+
216
+ return f"{crc:08x}"
217
+
77
218
  @staticmethod
78
219
  def analyze_device_in_kwargs(element):
79
220
  single_arg = {}
@@ -94,80 +235,6 @@ class PytorchDataProcessor(BaseDataProcessor):
94
235
  def analyze_dtype_in_kwargs(element):
95
236
  return {"type": "torch.dtype", "value": str(element)}
96
237
 
97
- @staticmethod
98
- def get_stat_info_async(data):
99
- tensor_stat = TensorStatInfo()
100
- if torch.is_complex(data):
101
- logger.warning("Async dump do not support complex data!")
102
- return tensor_stat
103
- elif data.dtype == torch.bool:
104
- tensor_stat.max = torch.any(data)
105
- tensor_stat.min = torch.all(data)
106
- elif not data.shape:
107
- tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
108
- else:
109
- if data.dtype == torch.float64 or not data.is_floating_point():
110
- data = data.float()
111
- tensor_stat.max = torch.max(data)
112
- tensor_stat.min = torch.min(data)
113
- tensor_stat.mean = torch.mean(data)
114
- tensor_stat.norm = torch.norm(data)
115
- return tensor_stat
116
-
117
- @staticmethod
118
- def get_stat_info_sync(data):
119
- tensor_stat = TensorStatInfo()
120
- if torch.is_complex(data):
121
- data_np = data.cpu().numpy()
122
- data_abs = np.abs(data_np)
123
- tensor_stat.max = np.max(data_abs).item()
124
- tensor_stat.min = np.min(data_abs).item()
125
- tensor_stat.mean = np.mean(data_abs).item()
126
- elif data.dtype == torch.bool:
127
- tensor_stat.max = torch.any(data)
128
- tensor_stat.min = torch.all(data)
129
- elif not data.shape:
130
- tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
131
- else:
132
- if data.dtype == torch.float64 or not data.is_floating_point():
133
- data = data.float()
134
- tensor_stat.max = torch.max(data)
135
- tensor_stat.min = torch.min(data)
136
- tensor_stat.mean = torch.mean(data)
137
- tensor_stat.norm = torch.norm(data)
138
- return tensor_stat
139
-
140
- @staticmethod
141
- def get_stat_info(data, async_dump=False):
142
- tensor_stat = TensorStatInfo()
143
- if data.is_meta:
144
- return tensor_stat
145
- data_clone = data.detach()
146
- if not data_clone.numel() or not data_clone.data_ptr():
147
- return tensor_stat
148
- else:
149
- if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump:
150
- return PytorchDataProcessor.get_stat_info_sync(data_clone)
151
- else:
152
- return PytorchDataProcessor.get_stat_info_async(data_clone)
153
-
154
- @staticmethod
155
- def handle_tensor_extremum_nan_inf(tensor, operator):
156
- data_clone = tensor.detach()
157
- data_nan = torch.isnan(data_clone)
158
- if int(torch.sum(data_nan)) == data_clone.numel():
159
- return float('nan')
160
-
161
- finite_mask = torch.isfinite(data_clone)
162
- if int(torch.sum(finite_mask)) > 0:
163
- finite_values = data_clone[finite_mask]
164
- return torch.max(finite_values).item() if operator == 'max' else \
165
- torch.min(finite_values).item()
166
- else:
167
- data_no_nan = data_clone[~data_nan]
168
- return torch.max(data_no_nan).item() if operator == 'max' else \
169
- torch.min(data_no_nan).item()
170
-
171
238
  @staticmethod
172
239
  def process_group_hash(arg):
173
240
  group_ranks = dist.get_process_group_ranks(arg)
@@ -214,9 +281,40 @@ class PytorchDataProcessor(BaseDataProcessor):
214
281
  def get_special_types(cls):
215
282
  return super().get_special_types() + cls.pytorch_special_type
216
283
 
284
+ def get_stat_info(self, data, async_dump=False, precision=Const.DUMP_PRECISION_LOW):
285
+ tensor_stat = TensorStatInfo()
286
+ if self.tensor_handler.is_empty_data(data):
287
+ return tensor_stat
288
+ data_clone = data.detach()
289
+ if not data_clone.numel() or not data_clone.data_ptr():
290
+ return tensor_stat
291
+ if torch.is_complex(data_clone):
292
+ if async_dump:
293
+ logger.warning("Async dump do not support complex data!")
294
+ return tensor_stat
295
+ data_np = data_clone.cpu().numpy()
296
+ data_abs = np.abs(data_np)
297
+ tensor_stat.max = np.max(data_abs).item()
298
+ tensor_stat.min = np.min(data_abs).item()
299
+ tensor_stat.mean = np.mean(data_abs).item()
300
+ elif data_clone.dtype == torch.bool:
301
+ tensor_stat.max = torch.any(data_clone)
302
+ tensor_stat.min = torch.all(data_clone)
303
+ elif not data_clone.shape:
304
+ tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.clone()
305
+ else:
306
+ if (precision == Const.DUMP_PRECISION_HIGH or data_clone.dtype == torch.float64
307
+ or not data_clone.is_floating_point()):
308
+ data_clone = data_clone.float()
309
+ tensor_stat.max = torch.max(data_clone)
310
+ tensor_stat.min = torch.min(data_clone)
311
+ tensor_stat.mean = torch.mean(data_clone)
312
+ tensor_stat.norm = torch.norm(data_clone)
313
+ return tensor_stat
314
+
217
315
  def dump_async_data(self):
218
316
  for file_path, tensor in self._async_dump_cache.items():
219
- save_pt(tensor.contiguous(), file_path)
317
+ self.tensor_handler.save_tensor(tensor, file_path)
220
318
  self._async_dump_cache.clear()
221
319
 
222
320
  def analyze_single_element(self, element, suffix_stack):
@@ -256,11 +354,12 @@ class PytorchDataProcessor(BaseDataProcessor):
256
354
  return p2pop_info
257
355
 
258
356
  def _analyze_tensor(self, tensor, suffix):
259
- tensor_stat = self.get_stat_info(tensor, self.config.async_dump)
357
+ common_tensor = self.tensor_handler.convert_common_tensor(tensor)
358
+ tensor_stat = self.get_stat_info(common_tensor, self.config.async_dump, self.config.precision)
260
359
  tensor_json = {}
261
- tensor_json.update({'type': 'torch.Tensor'})
262
- tensor_json.update({'dtype': str(tensor.dtype)})
263
- tensor_json.update({"shape": tensor.shape})
360
+ tensor_json.update({'type': self.tensor_handler.get_tensor_type(tensor)})
361
+ tensor_json.update({'dtype': str(common_tensor.dtype)})
362
+ tensor_json.update({"shape": common_tensor.shape})
264
363
 
265
364
  stat_values = [
266
365
  tensor_stat.max,
@@ -272,26 +371,64 @@ class PytorchDataProcessor(BaseDataProcessor):
272
371
 
273
372
  tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index})
274
373
  tensor_json.update({"requires_grad": tensor.requires_grad})
374
+ if self.tensor_handler.is_dtensor(tensor):
375
+ dtensor_info = self.tensor_handler.get_dtensor_info(tensor)
376
+ tensor_json.update(dtensor_info)
275
377
 
276
378
  if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
277
- tensor_md5 = self.get_md5_for_tensor(tensor)
278
- tensor_json.update({Const.MD5: tensor_md5})
379
+ tensor_md5 = None
380
+ if not self.tensor_handler.is_empty_data(tensor):
381
+ t_cpu = common_tensor
382
+
383
+ # 根据设备类型做同步,确保数据已准备好
384
+ if t_cpu.device.type == "cuda":
385
+ t_cpu = t_cpu.to("cpu", non_blocking=True)
386
+ torch.cuda.synchronize()
387
+ # 先异步搬运再进行同步可以显著提升性能
388
+ elif t_cpu.device.type == "npu":
389
+ t_cpu = t_cpu.to("cpu", non_blocking=True)
390
+ torch.npu.synchronize()
391
+
392
+ t_cpu = t_cpu.detach()
393
+ if not t_cpu.is_contiguous():
394
+ t_cpu = t_cpu.contiguous()
395
+
396
+ future = self._crc_executor.submit(
397
+ PytorchDataProcessor.compute_crc32_from_tensor,
398
+ t_cpu
399
+ )
400
+
401
+ crc_placeholder = self.data_writer.append_crc32_to_buffer(future)
402
+ tensor_json[Const.MD5_INDEX] = crc_placeholder
403
+ else:
404
+ logger.debug(
405
+ "Calculating the md5 value of fake tensor or meta tensor is not supported, "
406
+ f"the current api/module name is {self.current_api_or_module_name}."
407
+ )
408
+ tensor_json.update({Const.MD5: tensor_md5})
279
409
  return tensor_json
280
410
 
281
411
  def _analyze_and_save_tensor(self, tensor, suffix):
282
412
  dump_data_name, file_path = self.get_save_file_path(suffix)
283
413
  single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix)
414
+ common_tensor = self.tensor_handler.convert_common_tensor(tensor)
415
+ if self.tensor_handler.is_empty_data(common_tensor):
416
+ logger.debug(f"Saving fake tensor or meta tensor is not supported, the current tensor is {file_path}.")
417
+ return single_arg
418
+ if common_tensor.untyped_storage().data_ptr() == 0:
419
+ logger.debug(f"Saving null-pointer tensor is not supported, the current tensor is {file_path}.")
420
+ return single_arg
421
+
284
422
  single_arg.update({"data_name": dump_data_name})
285
423
  if self.config.async_dump:
286
- self._async_dump_cache[file_path] = tensor.clone().detach()
424
+ self._async_dump_cache[file_path] = common_tensor.clone().detach()
287
425
  else:
288
- saved_tensor = tensor.clone().contiguous().detach()
289
- save_pt(saved_tensor, file_path)
426
+ self.tensor_handler.save_tensor(common_tensor, file_path)
290
427
  return single_arg
291
428
 
292
429
  def _analyze_and_save_ndarray(self, ndarray, suffix):
293
430
  dump_data_name, file_path = self.get_save_file_path(suffix)
294
- save_pt(torch.tensor(ndarray), file_path)
431
+ self.tensor_handler.save_tensor(torch.tensor(ndarray), file_path)
295
432
  ndarray_json = PytorchDataProcessor._analyze_ndarray(ndarray, suffix)
296
433
  ndarray_json.update({"data_name": dump_data_name})
297
434
  return ndarray_json
@@ -382,7 +519,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
382
519
  self._analyze_maybe_overflow_flag()
383
520
  if self.has_overflow:
384
521
  for file_path, tensor in self.cached_tensors_and_file_paths.items():
385
- save_pt(tensor.clone().contiguous().detach(), file_path)
522
+ self.tensor_handler.save_tensor(tensor, file_path)
386
523
  self.real_overflow_nums += 1
387
524
  if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums:
388
525
  logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, "
@@ -427,10 +564,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
427
564
 
428
565
  def _analyze_tensor(self, tensor, suffix):
429
566
  dump_data_name, file_path = self.get_save_file_path(suffix)
430
- if not path_len_exceeds_limit(file_path):
431
- self.cached_tensors_and_file_paths.update({file_path: tensor})
432
- else:
433
- logger.warning(f'The file path {file_path} length exceeds limit.')
567
+ self.cached_tensors_and_file_paths.update({file_path: tensor})
434
568
  single_arg = super()._analyze_tensor(tensor, suffix)
435
569
  single_arg.update({"data_name": dump_data_name})
436
570
  if not self.has_overflow and self.support_inf_nan: