mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
|
@@ -13,7 +13,9 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
|
|
16
|
+
import os
|
|
16
17
|
import zlib
|
|
18
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
17
19
|
|
|
18
20
|
import mindspore as ms
|
|
19
21
|
from mindspore import mint, ops, hal
|
|
@@ -53,6 +55,11 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
53
55
|
}
|
|
54
56
|
self._async_dump_cache = {}
|
|
55
57
|
self.api_register = get_api_register()
|
|
58
|
+
self._crc_executor = ThreadPoolExecutor(max_workers=os.cpu_count() // 2)
|
|
59
|
+
|
|
60
|
+
@staticmethod
|
|
61
|
+
def compute_crc32_bytes(tensor_bytes):
|
|
62
|
+
return f"{zlib.crc32(tensor_bytes):08x}"
|
|
56
63
|
|
|
57
64
|
@staticmethod
|
|
58
65
|
def get_md5_for_tensor(x):
|
|
@@ -65,52 +72,6 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
65
72
|
def analyze_dtype_in_kwargs(element):
|
|
66
73
|
return {"type": "mindspore.dtype", "value": str(element)}
|
|
67
74
|
|
|
68
|
-
@staticmethod
|
|
69
|
-
def get_stat_info_sync(data):
|
|
70
|
-
tensor_stat = TensorStatInfo()
|
|
71
|
-
if data.dtype == ms.bool_:
|
|
72
|
-
data_np = data.asnumpy()
|
|
73
|
-
tensor_stat.max = np.max(data_np).item()
|
|
74
|
-
tensor_stat.min = np.min(data_np).item()
|
|
75
|
-
elif not data.shape:
|
|
76
|
-
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
|
|
77
|
-
elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
|
|
78
|
-
data_abs = np.abs(data.asnumpy())
|
|
79
|
-
tensor_stat.max = np.max(data_abs).item()
|
|
80
|
-
tensor_stat.min = np.min(data_abs).item()
|
|
81
|
-
tensor_stat.mean = np.mean(data_abs).item()
|
|
82
|
-
tensor_stat.norm = np.linalg.norm(data_abs).item()
|
|
83
|
-
else:
|
|
84
|
-
if not ops.is_floating_point(data) or data.dtype == ms.float64:
|
|
85
|
-
data = data.to(ms.float32)
|
|
86
|
-
get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
|
|
87
|
-
tensor_stat.max = mint.max(data)
|
|
88
|
-
tensor_stat.min = mint.min(data)
|
|
89
|
-
tensor_stat.mean = mint.mean(data)
|
|
90
|
-
tensor_stat.norm = get_norm_value(data)
|
|
91
|
-
return tensor_stat
|
|
92
|
-
|
|
93
|
-
@staticmethod
|
|
94
|
-
def get_stat_info_async(data):
|
|
95
|
-
tensor_stat = TensorStatInfo()
|
|
96
|
-
if data.dtype == ms.bool_:
|
|
97
|
-
tensor_stat.max = mint.any(data)
|
|
98
|
-
tensor_stat.min = mint.all(data)
|
|
99
|
-
elif not data.shape:
|
|
100
|
-
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
|
|
101
|
-
elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
|
|
102
|
-
logger.warning("Async dump do not support complex data!")
|
|
103
|
-
return tensor_stat
|
|
104
|
-
else:
|
|
105
|
-
if not ops.is_floating_point(data) or data.dtype == ms.float64:
|
|
106
|
-
data = data.to(ms.float32)
|
|
107
|
-
get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
|
|
108
|
-
tensor_stat.max = mint.max(data)
|
|
109
|
-
tensor_stat.min = mint.min(data)
|
|
110
|
-
tensor_stat.mean = mint.mean(data)
|
|
111
|
-
tensor_stat.norm = get_norm_value(data)
|
|
112
|
-
return tensor_stat
|
|
113
|
-
|
|
114
75
|
@staticmethod
|
|
115
76
|
def is_hookable_element(element):
|
|
116
77
|
return hasattr(element, "register_hook") and callable(element.register_hook)
|
|
@@ -147,14 +108,37 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
147
108
|
self.api_register.restore_inner_used_api()
|
|
148
109
|
tensor_stat = TensorStatInfo()
|
|
149
110
|
if data.numel() == 0:
|
|
150
|
-
|
|
151
|
-
|
|
111
|
+
pass
|
|
112
|
+
elif data.dtype == ms.bool_:
|
|
113
|
+
if self.config.async_dump:
|
|
114
|
+
tensor_stat.max = mint.any(data)
|
|
115
|
+
tensor_stat.min = mint.all(data)
|
|
116
|
+
else:
|
|
117
|
+
data_np = data.asnumpy()
|
|
118
|
+
tensor_stat.max = np.max(data_np).item()
|
|
119
|
+
tensor_stat.min = np.min(data_np).item()
|
|
120
|
+
elif not data.shape:
|
|
121
|
+
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.copy()
|
|
122
|
+
elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
|
|
152
123
|
if self.config.async_dump:
|
|
153
|
-
|
|
124
|
+
logger.warning("Async dump do not support complex data!")
|
|
154
125
|
else:
|
|
155
|
-
|
|
126
|
+
data_abs = np.abs(data.asnumpy())
|
|
127
|
+
tensor_stat.max = np.max(data_abs).item()
|
|
128
|
+
tensor_stat.min = np.min(data_abs).item()
|
|
129
|
+
tensor_stat.mean = np.mean(data_abs).item()
|
|
130
|
+
tensor_stat.norm = np.linalg.norm(data_abs).item()
|
|
131
|
+
else:
|
|
132
|
+
if self.config.precision == Const.DUMP_PRECISION_HIGH or not ops.is_floating_point(
|
|
133
|
+
data) or data.dtype == ms.float64:
|
|
134
|
+
data = data.to(ms.float32)
|
|
135
|
+
get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
|
|
136
|
+
tensor_stat.max = mint.max(data)
|
|
137
|
+
tensor_stat.min = mint.min(data)
|
|
138
|
+
tensor_stat.mean = mint.mean(data)
|
|
139
|
+
tensor_stat.norm = get_norm_value(data)
|
|
156
140
|
self.api_register.register_inner_used_api()
|
|
157
|
-
return
|
|
141
|
+
return tensor_stat
|
|
158
142
|
|
|
159
143
|
def analyze_single_element(self, element, suffix_stack):
|
|
160
144
|
if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
|
|
@@ -211,8 +195,18 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
211
195
|
tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index})
|
|
212
196
|
|
|
213
197
|
if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
|
|
214
|
-
|
|
215
|
-
|
|
198
|
+
tensor = convert_bf16_to_fp32(tensor)
|
|
199
|
+
# 拷贝并搬到 CPU
|
|
200
|
+
tensor_bytes = tensor.asnumpy()
|
|
201
|
+
|
|
202
|
+
future = self._crc_executor.submit(
|
|
203
|
+
MindsporeDataProcessor.compute_crc32_bytes,
|
|
204
|
+
tensor_bytes
|
|
205
|
+
)
|
|
206
|
+
|
|
207
|
+
crc_placeholder = self.data_writer.append_crc32_to_buffer(future)
|
|
208
|
+
tensor_json[Const.MD5_INDEX] = crc_placeholder
|
|
209
|
+
|
|
216
210
|
return tensor_json
|
|
217
211
|
|
|
218
212
|
def _analyze_and_save_tensor(self, tensor, suffix):
|
|
@@ -13,7 +13,11 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import ctypes
|
|
17
|
+
import os
|
|
16
18
|
import zlib
|
|
19
|
+
from collections.abc import Iterable
|
|
20
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
17
21
|
from dataclasses import asdict
|
|
18
22
|
from typing import List
|
|
19
23
|
|
|
@@ -23,11 +27,10 @@ from torch import distributed as dist
|
|
|
23
27
|
from torch.distributed.distributed_c10d import _get_default_group
|
|
24
28
|
|
|
25
29
|
from msprobe.core.common.const import Const
|
|
30
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
26
31
|
from msprobe.core.common.exceptions import MsprobeException
|
|
27
|
-
from msprobe.core.common.file_utils import path_len_exceeds_limit
|
|
28
32
|
from msprobe.core.common.log import logger
|
|
29
|
-
from msprobe.core.common.utils import convert_tuple
|
|
30
|
-
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
33
|
+
from msprobe.core.common.utils import convert_tuple, is_int
|
|
31
34
|
from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
|
|
32
35
|
ModuleForwardInputsOutputs, TensorStatInfo
|
|
33
36
|
from msprobe.pytorch.common.utils import save_pt
|
|
@@ -40,6 +43,84 @@ except ImportError:
|
|
|
40
43
|
is_gpu = True
|
|
41
44
|
|
|
42
45
|
|
|
46
|
+
class TensorHandler:
|
|
47
|
+
def __init__(self):
|
|
48
|
+
self.has_dtensor = hasattr(dist, "tensor") and hasattr(dist.tensor, "DTensor")
|
|
49
|
+
self.has_fake_tensor = hasattr(torch, "_subclasses") and hasattr(torch._subclasses, "fake_tensor")
|
|
50
|
+
self.has_async_collective_tensor = hasattr(dist, "_functional_collectives") and \
|
|
51
|
+
hasattr(dist._functional_collectives, "AsyncCollectiveTensor")
|
|
52
|
+
|
|
53
|
+
@staticmethod
|
|
54
|
+
def free_tensor(tensor, tensor_name):
|
|
55
|
+
try:
|
|
56
|
+
tensor.untyped_storage().resize_(0)
|
|
57
|
+
except Exception as e:
|
|
58
|
+
logger.warning(f"Failed to free tensor: {tensor_name}, the detail info: {e}.")
|
|
59
|
+
|
|
60
|
+
def is_dtensor(self, tensor):
|
|
61
|
+
return self.has_dtensor and isinstance(tensor, dist.tensor.DTensor)
|
|
62
|
+
|
|
63
|
+
def is_fake_tensor(self, tensor):
|
|
64
|
+
return self.has_fake_tensor and isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor)
|
|
65
|
+
|
|
66
|
+
def is_async_collective_tensor(self, tensor):
|
|
67
|
+
return self.has_async_collective_tensor and \
|
|
68
|
+
isinstance(tensor, dist._functional_collectives.AsyncCollectiveTensor)
|
|
69
|
+
|
|
70
|
+
def is_empty_data(self, tensor):
|
|
71
|
+
return tensor.is_meta or self.is_fake_tensor(tensor) or self.is_async_collective_tensor(tensor)
|
|
72
|
+
|
|
73
|
+
def convert_common_tensor(self, tensor):
|
|
74
|
+
if self.is_dtensor(tensor):
|
|
75
|
+
return tensor.to_local()
|
|
76
|
+
if self.is_fake_tensor(tensor):
|
|
77
|
+
logger.debug("FakeTensor cannot be converted to torch.Tensor type.")
|
|
78
|
+
return tensor
|
|
79
|
+
return tensor
|
|
80
|
+
|
|
81
|
+
def get_tensor_type(self, tensor):
|
|
82
|
+
if self.is_dtensor(tensor):
|
|
83
|
+
return Const.DTENSOR_TYPE
|
|
84
|
+
if self.is_fake_tensor(tensor):
|
|
85
|
+
return Const.FAKE_TENSOR_TYPE
|
|
86
|
+
if self.is_async_collective_tensor(tensor):
|
|
87
|
+
return Const.AC_TENSOR_TYPE
|
|
88
|
+
return Const.TENSOR_TYPE
|
|
89
|
+
|
|
90
|
+
def get_dtensor_info(self, tensor):
|
|
91
|
+
dtensor_info = {}
|
|
92
|
+
if not self.is_dtensor(tensor):
|
|
93
|
+
return dtensor_info
|
|
94
|
+
if hasattr(tensor, "device_mesh") and tensor.device_mesh:
|
|
95
|
+
dtensor_info.update({"device_mesh": tensor.device_mesh.mesh.tolist()})
|
|
96
|
+
|
|
97
|
+
placements = []
|
|
98
|
+
if hasattr(tensor, "placements") and isinstance(tensor.placements, Iterable):
|
|
99
|
+
for placement in tensor.placements:
|
|
100
|
+
if placement.is_shard() and is_int(placement.dim):
|
|
101
|
+
placements.append({"Shard": {"dim": placement.dim}})
|
|
102
|
+
continue
|
|
103
|
+
if placement.is_replicate():
|
|
104
|
+
placements.append({"Replicate": {}})
|
|
105
|
+
continue
|
|
106
|
+
if placement.is_partial() and isinstance(placement.reduce_op, str):
|
|
107
|
+
placements.append({"Partial": {"reduce_op": placement.reduce_op}})
|
|
108
|
+
dtensor_info.update({"placements": placements})
|
|
109
|
+
return dtensor_info
|
|
110
|
+
|
|
111
|
+
def save_tensor(self, tensor, file_path):
|
|
112
|
+
common_tensor = self.convert_common_tensor(tensor)
|
|
113
|
+
if self.is_empty_data(common_tensor):
|
|
114
|
+
logger.debug(f"Saving fake tensor or meta tensor is not supported, the current tensor is {file_path}.")
|
|
115
|
+
return
|
|
116
|
+
if common_tensor.untyped_storage().data_ptr() == 0:
|
|
117
|
+
logger.debug(f"Saving null-pointer tensor is not supported, the current tensor is {file_path}.")
|
|
118
|
+
return
|
|
119
|
+
saved_tensor = common_tensor.clone().contiguous().detach()
|
|
120
|
+
save_pt(saved_tensor, file_path)
|
|
121
|
+
self.free_tensor(saved_tensor, file_path)
|
|
122
|
+
|
|
123
|
+
|
|
43
124
|
class PytorchDataProcessor(BaseDataProcessor):
|
|
44
125
|
pytorch_special_type = (
|
|
45
126
|
torch.device,
|
|
@@ -65,6 +146,8 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
65
146
|
"dtype": self.analyze_dtype_in_kwargs
|
|
66
147
|
}
|
|
67
148
|
self._async_dump_cache = {}
|
|
149
|
+
self.tensor_handler = TensorHandler()
|
|
150
|
+
self._crc_executor = ThreadPoolExecutor(max_workers=os.cpu_count() // 2)
|
|
68
151
|
|
|
69
152
|
@staticmethod
|
|
70
153
|
def get_md5_for_tensor(x):
|
|
@@ -74,6 +157,64 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
74
157
|
crc32_hash = zlib.crc32(tensor_bytes)
|
|
75
158
|
return f"{crc32_hash:08x}"
|
|
76
159
|
|
|
160
|
+
@staticmethod
|
|
161
|
+
def tensor_bytes_view_cpu(t: torch.Tensor):
|
|
162
|
+
"""
|
|
163
|
+
返回 t 在当前 dtype 下的原始字节视图(优先零拷贝)。
|
|
164
|
+
需保证:t 已在 CPU 且是 contiguous。
|
|
165
|
+
可能返回 memoryview 或 bytes(兜底拷贝)或者 转为numpy,均可被 zlib.crc32 接受。
|
|
166
|
+
"""
|
|
167
|
+
|
|
168
|
+
nbytes = t.numel() * t.element_size()
|
|
169
|
+
byte_offset = t.storage_offset() * t.element_size()
|
|
170
|
+
|
|
171
|
+
if nbytes == 0:
|
|
172
|
+
return memoryview(b"")
|
|
173
|
+
|
|
174
|
+
storage = t.untyped_storage()
|
|
175
|
+
|
|
176
|
+
# ctypes 指针构造 memoryview(零拷贝 FFI)
|
|
177
|
+
try:
|
|
178
|
+
addr = storage.data_ptr() + byte_offset
|
|
179
|
+
buf = (ctypes.c_ubyte * nbytes).from_address(addr)
|
|
180
|
+
mv3 = memoryview(buf)
|
|
181
|
+
|
|
182
|
+
return mv3
|
|
183
|
+
except Exception as e1:
|
|
184
|
+
logger.warning(f"path_A_failed: {e1}.")
|
|
185
|
+
|
|
186
|
+
try:
|
|
187
|
+
data = ctypes.string_at(storage.data_ptr() + byte_offset, nbytes)
|
|
188
|
+
|
|
189
|
+
return data # bytes 也可直接用于 zlib.crc32
|
|
190
|
+
except Exception as e2:
|
|
191
|
+
logger.warning(f"path_B_failed: {e2}.")
|
|
192
|
+
|
|
193
|
+
try:
|
|
194
|
+
if t.dtype == torch.bfloat16:
|
|
195
|
+
t = t.float()
|
|
196
|
+
data = t.numpy()
|
|
197
|
+
|
|
198
|
+
return data
|
|
199
|
+
except Exception as e3:
|
|
200
|
+
logger.warning(f"path_C_failed: {e3}.")
|
|
201
|
+
return memoryview(b"")
|
|
202
|
+
|
|
203
|
+
@staticmethod
|
|
204
|
+
def compute_crc32_from_tensor(t: torch.Tensor) -> str:
|
|
205
|
+
"""
|
|
206
|
+
直接对 Tensor 原始字节做 CRC32。
|
|
207
|
+
:
|
|
208
|
+
- "raw": 保持 bfloat16 原始 16bit 字节(推荐,避免升精/增容)
|
|
209
|
+
"""
|
|
210
|
+
|
|
211
|
+
# 取得字节视图(含多级回退),然后做 CRC
|
|
212
|
+
mv = PytorchDataProcessor.tensor_bytes_view_cpu(t)
|
|
213
|
+
|
|
214
|
+
crc = zlib.crc32(mv)
|
|
215
|
+
|
|
216
|
+
return f"{crc:08x}"
|
|
217
|
+
|
|
77
218
|
@staticmethod
|
|
78
219
|
def analyze_device_in_kwargs(element):
|
|
79
220
|
single_arg = {}
|
|
@@ -94,80 +235,6 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
94
235
|
def analyze_dtype_in_kwargs(element):
|
|
95
236
|
return {"type": "torch.dtype", "value": str(element)}
|
|
96
237
|
|
|
97
|
-
@staticmethod
|
|
98
|
-
def get_stat_info_async(data):
|
|
99
|
-
tensor_stat = TensorStatInfo()
|
|
100
|
-
if torch.is_complex(data):
|
|
101
|
-
logger.warning("Async dump do not support complex data!")
|
|
102
|
-
return tensor_stat
|
|
103
|
-
elif data.dtype == torch.bool:
|
|
104
|
-
tensor_stat.max = torch.any(data)
|
|
105
|
-
tensor_stat.min = torch.all(data)
|
|
106
|
-
elif not data.shape:
|
|
107
|
-
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
|
|
108
|
-
else:
|
|
109
|
-
if data.dtype == torch.float64 or not data.is_floating_point():
|
|
110
|
-
data = data.float()
|
|
111
|
-
tensor_stat.max = torch.max(data)
|
|
112
|
-
tensor_stat.min = torch.min(data)
|
|
113
|
-
tensor_stat.mean = torch.mean(data)
|
|
114
|
-
tensor_stat.norm = torch.norm(data)
|
|
115
|
-
return tensor_stat
|
|
116
|
-
|
|
117
|
-
@staticmethod
|
|
118
|
-
def get_stat_info_sync(data):
|
|
119
|
-
tensor_stat = TensorStatInfo()
|
|
120
|
-
if torch.is_complex(data):
|
|
121
|
-
data_np = data.cpu().numpy()
|
|
122
|
-
data_abs = np.abs(data_np)
|
|
123
|
-
tensor_stat.max = np.max(data_abs).item()
|
|
124
|
-
tensor_stat.min = np.min(data_abs).item()
|
|
125
|
-
tensor_stat.mean = np.mean(data_abs).item()
|
|
126
|
-
elif data.dtype == torch.bool:
|
|
127
|
-
tensor_stat.max = torch.any(data)
|
|
128
|
-
tensor_stat.min = torch.all(data)
|
|
129
|
-
elif not data.shape:
|
|
130
|
-
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
|
|
131
|
-
else:
|
|
132
|
-
if data.dtype == torch.float64 or not data.is_floating_point():
|
|
133
|
-
data = data.float()
|
|
134
|
-
tensor_stat.max = torch.max(data)
|
|
135
|
-
tensor_stat.min = torch.min(data)
|
|
136
|
-
tensor_stat.mean = torch.mean(data)
|
|
137
|
-
tensor_stat.norm = torch.norm(data)
|
|
138
|
-
return tensor_stat
|
|
139
|
-
|
|
140
|
-
@staticmethod
|
|
141
|
-
def get_stat_info(data, async_dump=False):
|
|
142
|
-
tensor_stat = TensorStatInfo()
|
|
143
|
-
if data.is_meta:
|
|
144
|
-
return tensor_stat
|
|
145
|
-
data_clone = data.detach()
|
|
146
|
-
if not data_clone.numel() or not data_clone.data_ptr():
|
|
147
|
-
return tensor_stat
|
|
148
|
-
else:
|
|
149
|
-
if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump:
|
|
150
|
-
return PytorchDataProcessor.get_stat_info_sync(data_clone)
|
|
151
|
-
else:
|
|
152
|
-
return PytorchDataProcessor.get_stat_info_async(data_clone)
|
|
153
|
-
|
|
154
|
-
@staticmethod
|
|
155
|
-
def handle_tensor_extremum_nan_inf(tensor, operator):
|
|
156
|
-
data_clone = tensor.detach()
|
|
157
|
-
data_nan = torch.isnan(data_clone)
|
|
158
|
-
if int(torch.sum(data_nan)) == data_clone.numel():
|
|
159
|
-
return float('nan')
|
|
160
|
-
|
|
161
|
-
finite_mask = torch.isfinite(data_clone)
|
|
162
|
-
if int(torch.sum(finite_mask)) > 0:
|
|
163
|
-
finite_values = data_clone[finite_mask]
|
|
164
|
-
return torch.max(finite_values).item() if operator == 'max' else \
|
|
165
|
-
torch.min(finite_values).item()
|
|
166
|
-
else:
|
|
167
|
-
data_no_nan = data_clone[~data_nan]
|
|
168
|
-
return torch.max(data_no_nan).item() if operator == 'max' else \
|
|
169
|
-
torch.min(data_no_nan).item()
|
|
170
|
-
|
|
171
238
|
@staticmethod
|
|
172
239
|
def process_group_hash(arg):
|
|
173
240
|
group_ranks = dist.get_process_group_ranks(arg)
|
|
@@ -214,9 +281,40 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
214
281
|
def get_special_types(cls):
|
|
215
282
|
return super().get_special_types() + cls.pytorch_special_type
|
|
216
283
|
|
|
284
|
+
def get_stat_info(self, data, async_dump=False, precision=Const.DUMP_PRECISION_LOW):
|
|
285
|
+
tensor_stat = TensorStatInfo()
|
|
286
|
+
if self.tensor_handler.is_empty_data(data):
|
|
287
|
+
return tensor_stat
|
|
288
|
+
data_clone = data.detach()
|
|
289
|
+
if not data_clone.numel() or not data_clone.data_ptr():
|
|
290
|
+
return tensor_stat
|
|
291
|
+
if torch.is_complex(data_clone):
|
|
292
|
+
if async_dump:
|
|
293
|
+
logger.warning("Async dump do not support complex data!")
|
|
294
|
+
return tensor_stat
|
|
295
|
+
data_np = data_clone.cpu().numpy()
|
|
296
|
+
data_abs = np.abs(data_np)
|
|
297
|
+
tensor_stat.max = np.max(data_abs).item()
|
|
298
|
+
tensor_stat.min = np.min(data_abs).item()
|
|
299
|
+
tensor_stat.mean = np.mean(data_abs).item()
|
|
300
|
+
elif data_clone.dtype == torch.bool:
|
|
301
|
+
tensor_stat.max = torch.any(data_clone)
|
|
302
|
+
tensor_stat.min = torch.all(data_clone)
|
|
303
|
+
elif not data_clone.shape:
|
|
304
|
+
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.clone()
|
|
305
|
+
else:
|
|
306
|
+
if (precision == Const.DUMP_PRECISION_HIGH or data_clone.dtype == torch.float64
|
|
307
|
+
or not data_clone.is_floating_point()):
|
|
308
|
+
data_clone = data_clone.float()
|
|
309
|
+
tensor_stat.max = torch.max(data_clone)
|
|
310
|
+
tensor_stat.min = torch.min(data_clone)
|
|
311
|
+
tensor_stat.mean = torch.mean(data_clone)
|
|
312
|
+
tensor_stat.norm = torch.norm(data_clone)
|
|
313
|
+
return tensor_stat
|
|
314
|
+
|
|
217
315
|
def dump_async_data(self):
|
|
218
316
|
for file_path, tensor in self._async_dump_cache.items():
|
|
219
|
-
|
|
317
|
+
self.tensor_handler.save_tensor(tensor, file_path)
|
|
220
318
|
self._async_dump_cache.clear()
|
|
221
319
|
|
|
222
320
|
def analyze_single_element(self, element, suffix_stack):
|
|
@@ -256,11 +354,12 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
256
354
|
return p2pop_info
|
|
257
355
|
|
|
258
356
|
def _analyze_tensor(self, tensor, suffix):
|
|
259
|
-
|
|
357
|
+
common_tensor = self.tensor_handler.convert_common_tensor(tensor)
|
|
358
|
+
tensor_stat = self.get_stat_info(common_tensor, self.config.async_dump, self.config.precision)
|
|
260
359
|
tensor_json = {}
|
|
261
|
-
tensor_json.update({'type':
|
|
262
|
-
tensor_json.update({'dtype': str(
|
|
263
|
-
tensor_json.update({"shape":
|
|
360
|
+
tensor_json.update({'type': self.tensor_handler.get_tensor_type(tensor)})
|
|
361
|
+
tensor_json.update({'dtype': str(common_tensor.dtype)})
|
|
362
|
+
tensor_json.update({"shape": common_tensor.shape})
|
|
264
363
|
|
|
265
364
|
stat_values = [
|
|
266
365
|
tensor_stat.max,
|
|
@@ -272,26 +371,64 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
272
371
|
|
|
273
372
|
tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index})
|
|
274
373
|
tensor_json.update({"requires_grad": tensor.requires_grad})
|
|
374
|
+
if self.tensor_handler.is_dtensor(tensor):
|
|
375
|
+
dtensor_info = self.tensor_handler.get_dtensor_info(tensor)
|
|
376
|
+
tensor_json.update(dtensor_info)
|
|
275
377
|
|
|
276
378
|
if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
|
|
277
|
-
tensor_md5 =
|
|
278
|
-
|
|
379
|
+
tensor_md5 = None
|
|
380
|
+
if not self.tensor_handler.is_empty_data(tensor):
|
|
381
|
+
t_cpu = common_tensor
|
|
382
|
+
|
|
383
|
+
# 根据设备类型做同步,确保数据已准备好
|
|
384
|
+
if t_cpu.device.type == "cuda":
|
|
385
|
+
t_cpu = t_cpu.to("cpu", non_blocking=True)
|
|
386
|
+
torch.cuda.synchronize()
|
|
387
|
+
# 先异步搬运再进行同步可以显著提升性能
|
|
388
|
+
elif t_cpu.device.type == "npu":
|
|
389
|
+
t_cpu = t_cpu.to("cpu", non_blocking=True)
|
|
390
|
+
torch.npu.synchronize()
|
|
391
|
+
|
|
392
|
+
t_cpu = t_cpu.detach()
|
|
393
|
+
if not t_cpu.is_contiguous():
|
|
394
|
+
t_cpu = t_cpu.contiguous()
|
|
395
|
+
|
|
396
|
+
future = self._crc_executor.submit(
|
|
397
|
+
PytorchDataProcessor.compute_crc32_from_tensor,
|
|
398
|
+
t_cpu
|
|
399
|
+
)
|
|
400
|
+
|
|
401
|
+
crc_placeholder = self.data_writer.append_crc32_to_buffer(future)
|
|
402
|
+
tensor_json[Const.MD5_INDEX] = crc_placeholder
|
|
403
|
+
else:
|
|
404
|
+
logger.debug(
|
|
405
|
+
"Calculating the md5 value of fake tensor or meta tensor is not supported, "
|
|
406
|
+
f"the current api/module name is {self.current_api_or_module_name}."
|
|
407
|
+
)
|
|
408
|
+
tensor_json.update({Const.MD5: tensor_md5})
|
|
279
409
|
return tensor_json
|
|
280
410
|
|
|
281
411
|
def _analyze_and_save_tensor(self, tensor, suffix):
|
|
282
412
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
283
413
|
single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix)
|
|
414
|
+
common_tensor = self.tensor_handler.convert_common_tensor(tensor)
|
|
415
|
+
if self.tensor_handler.is_empty_data(common_tensor):
|
|
416
|
+
logger.debug(f"Saving fake tensor or meta tensor is not supported, the current tensor is {file_path}.")
|
|
417
|
+
return single_arg
|
|
418
|
+
if common_tensor.untyped_storage().data_ptr() == 0:
|
|
419
|
+
logger.debug(f"Saving null-pointer tensor is not supported, the current tensor is {file_path}.")
|
|
420
|
+
return single_arg
|
|
421
|
+
|
|
284
422
|
single_arg.update({"data_name": dump_data_name})
|
|
285
423
|
if self.config.async_dump:
|
|
286
|
-
self._async_dump_cache[file_path] =
|
|
424
|
+
self._async_dump_cache[file_path] = common_tensor.clone().detach()
|
|
287
425
|
else:
|
|
288
|
-
|
|
289
|
-
save_pt(saved_tensor, file_path)
|
|
426
|
+
self.tensor_handler.save_tensor(common_tensor, file_path)
|
|
290
427
|
return single_arg
|
|
291
428
|
|
|
292
429
|
def _analyze_and_save_ndarray(self, ndarray, suffix):
|
|
293
430
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
294
|
-
|
|
431
|
+
self.tensor_handler.save_tensor(torch.tensor(ndarray), file_path)
|
|
295
432
|
ndarray_json = PytorchDataProcessor._analyze_ndarray(ndarray, suffix)
|
|
296
433
|
ndarray_json.update({"data_name": dump_data_name})
|
|
297
434
|
return ndarray_json
|
|
@@ -382,7 +519,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
382
519
|
self._analyze_maybe_overflow_flag()
|
|
383
520
|
if self.has_overflow:
|
|
384
521
|
for file_path, tensor in self.cached_tensors_and_file_paths.items():
|
|
385
|
-
|
|
522
|
+
self.tensor_handler.save_tensor(tensor, file_path)
|
|
386
523
|
self.real_overflow_nums += 1
|
|
387
524
|
if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums:
|
|
388
525
|
logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, "
|
|
@@ -427,10 +564,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
427
564
|
|
|
428
565
|
def _analyze_tensor(self, tensor, suffix):
|
|
429
566
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
430
|
-
|
|
431
|
-
self.cached_tensors_and_file_paths.update({file_path: tensor})
|
|
432
|
-
else:
|
|
433
|
-
logger.warning(f'The file path {file_path} length exceeds limit.')
|
|
567
|
+
self.cached_tensors_and_file_paths.update({file_path: tensor})
|
|
434
568
|
single_arg = super()._analyze_tensor(tensor, suffix)
|
|
435
569
|
single_arg.update({"data_name": dump_data_name})
|
|
436
570
|
if not self.has_overflow and self.support_inf_nan:
|