mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -17,13 +17,14 @@ import zlib
|
|
|
17
17
|
|
|
18
18
|
import mindspore as ms
|
|
19
19
|
from mindspore import mint, ops, hal
|
|
20
|
+
from mindspore.mint import distributed
|
|
20
21
|
from mindspore._c_expression.typing import Number
|
|
21
22
|
import numpy as np
|
|
22
23
|
|
|
23
24
|
from msprobe.core.common.const import Const
|
|
24
25
|
from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo,
|
|
25
26
|
ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs)
|
|
26
|
-
from msprobe.core.common.file_utils import path_len_exceeds_limit
|
|
27
|
+
from msprobe.core.common.file_utils import path_len_exceeds_limit
|
|
27
28
|
from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy
|
|
28
29
|
from msprobe.mindspore.common.log import logger
|
|
29
30
|
from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
|
|
@@ -36,7 +37,7 @@ except ImportError:
|
|
|
36
37
|
|
|
37
38
|
|
|
38
39
|
class MindsporeDataProcessor(BaseDataProcessor):
|
|
39
|
-
mindspore_special_type = tuple([ms.Tensor, Number])
|
|
40
|
+
mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp])
|
|
40
41
|
|
|
41
42
|
def __init__(self, config, data_writer):
|
|
42
43
|
super().__init__(config, data_writer)
|
|
@@ -65,7 +66,7 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
65
66
|
tensor_stat.max = np.max(data_np).item()
|
|
66
67
|
tensor_stat.min = np.min(data_np).item()
|
|
67
68
|
elif not data.shape:
|
|
68
|
-
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
|
|
69
|
+
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
|
|
69
70
|
elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
|
|
70
71
|
data_abs = np.abs(data.asnumpy())
|
|
71
72
|
tensor_stat.max = np.max(data_abs).item()
|
|
@@ -76,38 +77,52 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
76
77
|
if not ops.is_floating_point(data) or data.dtype == ms.float64:
|
|
77
78
|
data = data.to(ms.float32)
|
|
78
79
|
get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
|
|
79
|
-
tensor_stat.max = mint.max(data)
|
|
80
|
-
tensor_stat.min = mint.min(data)
|
|
81
|
-
tensor_stat.mean = mint.mean(data)
|
|
82
|
-
tensor_stat.norm = get_norm_value(data)
|
|
80
|
+
tensor_stat.max = mint.max(data)
|
|
81
|
+
tensor_stat.min = mint.min(data)
|
|
82
|
+
tensor_stat.mean = mint.mean(data)
|
|
83
|
+
tensor_stat.norm = get_norm_value(data)
|
|
83
84
|
return tensor_stat
|
|
84
85
|
|
|
85
86
|
@staticmethod
|
|
86
87
|
def get_stat_info_async(data):
|
|
87
88
|
tensor_stat = TensorStatInfo()
|
|
88
|
-
if data.dtype == ms.
|
|
89
|
+
if data.dtype == ms.bool_:
|
|
90
|
+
tensor_stat.max = mint.any(data)
|
|
91
|
+
tensor_stat.min = mint.all(data)
|
|
92
|
+
elif not data.shape:
|
|
93
|
+
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
|
|
94
|
+
elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
|
|
89
95
|
logger.warning("Async dump do not support complex data!")
|
|
90
96
|
return tensor_stat
|
|
91
|
-
elif data.dtype == ms.bool_:
|
|
92
|
-
tensor_stat.stack_tensor_stat = (["Max", "Min"], ops.stack([data.any(), data.all()]))
|
|
93
|
-
elif not data.shape:
|
|
94
|
-
tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], ops.stack([data, data, data, data]))
|
|
95
97
|
else:
|
|
96
98
|
if not ops.is_floating_point(data) or data.dtype == ms.float64:
|
|
97
99
|
data = data.to(ms.float32)
|
|
98
100
|
get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
|
|
99
|
-
tensor_stat.
|
|
100
|
-
|
|
101
|
+
tensor_stat.max = mint.max(data)
|
|
102
|
+
tensor_stat.min = mint.min(data)
|
|
103
|
+
tensor_stat.mean = mint.mean(data)
|
|
104
|
+
tensor_stat.norm = get_norm_value(data)
|
|
101
105
|
return tensor_stat
|
|
102
106
|
|
|
103
107
|
@staticmethod
|
|
104
108
|
def is_hookable_element(element):
|
|
105
109
|
return hasattr(element, "register_hook") and callable(element.register_hook)
|
|
106
110
|
|
|
111
|
+
@staticmethod
|
|
112
|
+
def process_group_hash(arg):
|
|
113
|
+
group_ranks = distributed.get_process_group_ranks(arg)
|
|
114
|
+
group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8'))
|
|
115
|
+
return f"{group_ranks_hash:08x}"
|
|
116
|
+
|
|
107
117
|
@classmethod
|
|
108
118
|
def get_special_types(cls):
|
|
109
119
|
return super().get_special_types() + cls.mindspore_special_type
|
|
110
120
|
|
|
121
|
+
def dump_async_data(self):
|
|
122
|
+
for file_path, tensor in self._async_dump_cache.items():
|
|
123
|
+
save_tensor_as_npy(tensor, file_path)
|
|
124
|
+
self._async_dump_cache.clear()
|
|
125
|
+
|
|
111
126
|
def get_stat_info(self, data):
|
|
112
127
|
self.api_register.restore_inner_used_api()
|
|
113
128
|
tensor_stat = TensorStatInfo()
|
|
@@ -125,19 +140,34 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
125
140
|
if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
|
|
126
141
|
return self.mindspore_object_key[suffix_stack[-1]](element)
|
|
127
142
|
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
143
|
+
suffix_str = Const.SEP.join(str(s) for s in suffix_stack)
|
|
144
|
+
type_analyzer = [
|
|
145
|
+
(MindsporeDataProcessor.builtin_type, self._analyze_builtin),
|
|
146
|
+
(ms.Tensor, lambda e: self._analyze_tensor(e, suffix_str)),
|
|
147
|
+
(Number, self.analyze_dtype_in_kwargs),
|
|
148
|
+
(MindsporeDataProcessor.np_type[:-1], self._analyze_numpy),
|
|
149
|
+
(np.ndarray, lambda e: self._analyze_ndarray(e, suffix_str)),
|
|
150
|
+
(distributed.P2POp, lambda e: self._analyze_p2pop(e, suffix_str))
|
|
151
|
+
]
|
|
152
|
+
for type_key, analyze_fn in type_analyzer:
|
|
153
|
+
if isinstance(element, type_key):
|
|
154
|
+
return analyze_fn(element)
|
|
139
155
|
return {}
|
|
140
156
|
|
|
157
|
+
def _analyze_p2pop(self, arg, suffix):
|
|
158
|
+
p2pop_info = {"class_type": "mindspore.mint.distributed.P2POp"}
|
|
159
|
+
try:
|
|
160
|
+
tensor_info = self._analyze_tensor(arg.tensor, suffix)
|
|
161
|
+
p2pop_info.update({"tensor": tensor_info})
|
|
162
|
+
p2pop_info.update({"op": arg.op})
|
|
163
|
+
p2pop_info.update({"peer": arg.peer})
|
|
164
|
+
p2pop_info.update({"tag": arg.tag})
|
|
165
|
+
group_id = self.process_group_hash(arg.group) if arg.group else None
|
|
166
|
+
p2pop_info.update({"group_id": group_id})
|
|
167
|
+
except Exception as e:
|
|
168
|
+
logger.warning(f"Failed to parse the P2POp content with error info: {e}.")
|
|
169
|
+
return p2pop_info
|
|
170
|
+
|
|
141
171
|
def _analyze_tensor(self, tensor, suffix):
|
|
142
172
|
tensor_stat = self.get_stat_info(tensor)
|
|
143
173
|
tensor_json = {
|
|
@@ -146,32 +176,26 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
146
176
|
'shape': tensor.shape
|
|
147
177
|
}
|
|
148
178
|
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
179
|
+
# 将统计值存入全局 buffer,并返回占位索引
|
|
180
|
+
stat_values = [
|
|
181
|
+
tensor_stat.max,
|
|
182
|
+
tensor_stat.min,
|
|
183
|
+
tensor_stat.mean,
|
|
184
|
+
tensor_stat.norm
|
|
185
|
+
]
|
|
186
|
+
|
|
187
|
+
placeholder_index = self.data_writer.append_stat_to_buffer(stat_values)
|
|
188
|
+
|
|
189
|
+
tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index})
|
|
190
|
+
|
|
156
191
|
if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
|
|
157
192
|
tensor_md5 = self.get_md5_for_tensor(tensor)
|
|
158
193
|
tensor_json.update({Const.MD5: tensor_md5})
|
|
159
194
|
return tensor_json
|
|
160
195
|
|
|
161
|
-
|
|
162
|
-
class StatisticsDataProcessor(MindsporeDataProcessor):
|
|
163
|
-
pass
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
class TensorDataProcessor(MindsporeDataProcessor):
|
|
167
|
-
def dump_async_data(self):
|
|
168
|
-
for file_path, tensor in self._async_dump_cache.items():
|
|
169
|
-
save_tensor_as_npy(tensor, file_path)
|
|
170
|
-
self._async_dump_cache.clear()
|
|
171
|
-
|
|
172
|
-
def _analyze_tensor(self, tensor, suffix):
|
|
196
|
+
def _analyze_and_save_tensor(self, tensor, suffix):
|
|
173
197
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
174
|
-
single_arg =
|
|
198
|
+
single_arg = MindsporeDataProcessor._analyze_tensor(self, tensor, suffix)
|
|
175
199
|
single_arg.update({"data_name": dump_data_name})
|
|
176
200
|
if self.config.async_dump:
|
|
177
201
|
self._async_dump_cache[file_path] = tensor.copy()
|
|
@@ -179,12 +203,27 @@ class TensorDataProcessor(MindsporeDataProcessor):
|
|
|
179
203
|
save_tensor_as_npy(tensor, file_path)
|
|
180
204
|
return single_arg
|
|
181
205
|
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
206
|
+
|
|
207
|
+
class StatisticsDataProcessor(MindsporeDataProcessor):
|
|
208
|
+
def _analyze_tensor(self, tensor, suffix):
|
|
209
|
+
if any(item in self.current_api_or_module_name for item in self.config.tensor_list):
|
|
210
|
+
return self._analyze_and_save_tensor(tensor, suffix)
|
|
211
|
+
else:
|
|
212
|
+
return super()._analyze_tensor(tensor, suffix)
|
|
213
|
+
|
|
214
|
+
def _analyze_ndarray(self, ndarray, suffix):
|
|
215
|
+
if any(item in self.current_api_or_module_name for item in self.config.tensor_list):
|
|
216
|
+
return self._analyze_and_save_ndarray(ndarray, suffix)
|
|
217
|
+
else:
|
|
218
|
+
return super()._analyze_ndarray(ndarray, suffix)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
class TensorDataProcessor(MindsporeDataProcessor):
|
|
222
|
+
def _analyze_tensor(self, tensor, suffix):
|
|
223
|
+
return self._analyze_and_save_tensor(tensor, suffix)
|
|
224
|
+
|
|
225
|
+
def _analyze_ndarray(self, ndarray, suffix):
|
|
226
|
+
return self._analyze_and_save_ndarray(ndarray, suffix)
|
|
188
227
|
|
|
189
228
|
|
|
190
229
|
class OverflowCheckDataProcessor(MindsporeDataProcessor):
|
|
@@ -231,7 +270,7 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
|
|
|
231
270
|
api_info_struct = super().analyze_backward(name, module, module_input_output)
|
|
232
271
|
self.maybe_save_overflow_data()
|
|
233
272
|
return api_info_struct if self.has_overflow else None
|
|
234
|
-
|
|
273
|
+
|
|
235
274
|
def analyze_params(self, name, param_name, grad):
|
|
236
275
|
self.has_overflow = False
|
|
237
276
|
api_info_struct = super().analyze_params(name, param_name, grad)
|
|
@@ -249,11 +288,26 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
|
|
|
249
288
|
self.cached_tensors_and_file_paths = {}
|
|
250
289
|
|
|
251
290
|
def _analyze_maybe_overflow_tensor(self, tensor_json):
|
|
252
|
-
|
|
291
|
+
tensor_stat_index = tensor_json.get(Const.TENSOR_STAT_INDEX)
|
|
292
|
+
if tensor_stat_index is None:
|
|
293
|
+
logger.warning("tensor_stat_index does not exist in tensor_json.")
|
|
253
294
|
return
|
|
254
|
-
|
|
295
|
+
max_tensor = self.data_writer.get_buffer_values_max(tensor_stat_index)
|
|
296
|
+
min_tensor = self.data_writer.get_buffer_values_min(tensor_stat_index)
|
|
297
|
+
if max_tensor is None or min_tensor is None:
|
|
298
|
+
return
|
|
299
|
+
|
|
300
|
+
def check_inf_nan(value):
|
|
301
|
+
# Use .item() if it's a tensor-like structure
|
|
302
|
+
if hasattr(value, "item"):
|
|
303
|
+
value = value.item()
|
|
304
|
+
return np.isinf(value) or np.isnan(value)
|
|
305
|
+
|
|
306
|
+
if check_inf_nan(max_tensor):
|
|
255
307
|
self.has_overflow = True
|
|
256
|
-
|
|
308
|
+
return
|
|
309
|
+
|
|
310
|
+
if check_inf_nan(min_tensor):
|
|
257
311
|
self.has_overflow = True
|
|
258
312
|
|
|
259
313
|
def _analyze_tensor(self, tensor, suffix):
|
|
@@ -13,7 +13,6 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import hashlib
|
|
17
16
|
import zlib
|
|
18
17
|
from dataclasses import asdict
|
|
19
18
|
from typing import List
|
|
@@ -102,19 +101,17 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
102
101
|
logger.warning("Async dump do not support complex data!")
|
|
103
102
|
return tensor_stat
|
|
104
103
|
elif data.dtype == torch.bool:
|
|
105
|
-
tensor_stat.
|
|
106
|
-
|
|
104
|
+
tensor_stat.max = torch.any(data)
|
|
105
|
+
tensor_stat.min = torch.all(data)
|
|
107
106
|
elif not data.shape:
|
|
108
|
-
tensor_stat.
|
|
107
|
+
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
|
|
109
108
|
else:
|
|
110
|
-
if
|
|
109
|
+
if data.dtype == torch.float64 or not data.is_floating_point():
|
|
111
110
|
data = data.float()
|
|
112
|
-
tensor_stat.
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
torch.norm(data)
|
|
117
|
-
]))
|
|
111
|
+
tensor_stat.max = torch.max(data)
|
|
112
|
+
tensor_stat.min = torch.min(data)
|
|
113
|
+
tensor_stat.mean = torch.mean(data)
|
|
114
|
+
tensor_stat.norm = torch.norm(data)
|
|
118
115
|
return tensor_stat
|
|
119
116
|
|
|
120
117
|
@staticmethod
|
|
@@ -127,17 +124,17 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
127
124
|
tensor_stat.min = np.min(data_abs).item()
|
|
128
125
|
tensor_stat.mean = np.mean(data_abs).item()
|
|
129
126
|
elif data.dtype == torch.bool:
|
|
130
|
-
tensor_stat.max = torch.any(data)
|
|
131
|
-
tensor_stat.min = torch.all(data)
|
|
127
|
+
tensor_stat.max = torch.any(data)
|
|
128
|
+
tensor_stat.min = torch.all(data)
|
|
132
129
|
elif not data.shape:
|
|
133
|
-
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
|
|
130
|
+
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
|
|
134
131
|
else:
|
|
135
|
-
if
|
|
132
|
+
if data.dtype == torch.float64 or not data.is_floating_point():
|
|
136
133
|
data = data.float()
|
|
137
|
-
tensor_stat.max = torch.max(data)
|
|
138
|
-
tensor_stat.min = torch.min(data)
|
|
139
|
-
tensor_stat.mean = torch.mean(data)
|
|
140
|
-
tensor_stat.norm = torch.norm(data)
|
|
134
|
+
tensor_stat.max = torch.max(data)
|
|
135
|
+
tensor_stat.min = torch.min(data)
|
|
136
|
+
tensor_stat.mean = torch.mean(data)
|
|
137
|
+
tensor_stat.norm = torch.norm(data)
|
|
141
138
|
return tensor_stat
|
|
142
139
|
|
|
143
140
|
@staticmethod
|
|
@@ -174,12 +171,8 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
174
171
|
@staticmethod
|
|
175
172
|
def process_group_hash(arg):
|
|
176
173
|
group_ranks = dist.get_process_group_ranks(arg)
|
|
177
|
-
group_ranks_hash =
|
|
178
|
-
return group_ranks_hash
|
|
179
|
-
|
|
180
|
-
@staticmethod
|
|
181
|
-
def is_distributed_op(module):
|
|
182
|
-
return getattr(module, "op_is_distributed", False)
|
|
174
|
+
group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8'))
|
|
175
|
+
return f"{group_ranks_hash:08x}"
|
|
183
176
|
|
|
184
177
|
@staticmethod
|
|
185
178
|
def is_hookable_element(element):
|
|
@@ -233,34 +226,31 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
233
226
|
def get_special_types(cls):
|
|
234
227
|
return super().get_special_types() + cls.pytorch_special_type
|
|
235
228
|
|
|
229
|
+
def dump_async_data(self):
|
|
230
|
+
for file_path, tensor in self._async_dump_cache.items():
|
|
231
|
+
save_pt(tensor.contiguous(), file_path)
|
|
232
|
+
self._async_dump_cache.clear()
|
|
233
|
+
|
|
236
234
|
def analyze_single_element(self, element, suffix_stack):
|
|
237
235
|
if suffix_stack and suffix_stack[-1] in self.torch_object_key:
|
|
238
236
|
return self.torch_object_key[suffix_stack[-1]](element)
|
|
239
|
-
if isinstance(element, torch.Size):
|
|
240
|
-
return self._analyze_torch_size(element)
|
|
241
|
-
if isinstance(element, torch.memory_format):
|
|
242
|
-
return self._analyze_memory_format(element)
|
|
243
|
-
if isinstance(element, dist.ProcessGroup):
|
|
244
|
-
return self._analyze_process_group(element)
|
|
245
|
-
if isinstance(element, dist.P2POp):
|
|
246
|
-
return self._analyze_p2pop(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
|
|
247
|
-
if isinstance(element, dist.ReduceOp):
|
|
248
|
-
return self._analyze_reduce_op(element)
|
|
249
|
-
converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
|
|
250
|
-
if converted_numpy is not element:
|
|
251
|
-
return {"type": numpy_type, "value": converted_numpy}
|
|
252
|
-
if isinstance(element, torch.Tensor):
|
|
253
|
-
return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
|
|
254
|
-
if isinstance(element, np.ndarray):
|
|
255
|
-
return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
|
|
256
|
-
if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
|
|
257
|
-
return self._analyze_builtin(element)
|
|
258
|
-
return {}
|
|
259
237
|
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
238
|
+
suffix_str = Const.SEP.join(str(s) for s in suffix_stack)
|
|
239
|
+
type_analyzer = [
|
|
240
|
+
(PytorchDataProcessor.builtin_type, self._analyze_builtin),
|
|
241
|
+
(torch.Size, self._analyze_torch_size),
|
|
242
|
+
(torch.Tensor, lambda e: self._analyze_tensor(e, suffix_str)),
|
|
243
|
+
(torch.memory_format, self._analyze_memory_format),
|
|
244
|
+
(dist.ProcessGroup, self._analyze_process_group),
|
|
245
|
+
(dist.P2POp, lambda e: self._analyze_p2pop(e, suffix_str)),
|
|
246
|
+
(dist.ReduceOp, self._analyze_reduce_op),
|
|
247
|
+
(PytorchDataProcessor.np_type[:-1], self._analyze_numpy),
|
|
248
|
+
(np.ndarray, lambda e: self._analyze_ndarray(e, suffix_str)),
|
|
249
|
+
]
|
|
250
|
+
for type_key, analyze_fn in type_analyzer:
|
|
251
|
+
if isinstance(element, type_key):
|
|
252
|
+
return analyze_fn(element)
|
|
253
|
+
return {}
|
|
264
254
|
|
|
265
255
|
def _analyze_p2pop(self, arg, suffix):
|
|
266
256
|
p2pop_info = {"class_type": "torch.distributed.P2POp"}
|
|
@@ -284,42 +274,26 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
284
274
|
tensor_json.update({'type': 'torch.Tensor'})
|
|
285
275
|
tensor_json.update({'dtype': dtype})
|
|
286
276
|
tensor_json.update({"shape": tensor.shape})
|
|
287
|
-
if tensor_stat.stack_tensor_stat is None:
|
|
288
|
-
tensor_json.update({"Max": tensor_stat.max})
|
|
289
|
-
tensor_json.update({"Min": tensor_stat.min})
|
|
290
|
-
tensor_json.update({"Mean": tensor_stat.mean})
|
|
291
|
-
tensor_json.update({"Norm": tensor_stat.norm})
|
|
292
|
-
tensor_json.update({"requires_grad": tensor.requires_grad})
|
|
293
|
-
if tensor_stat.max is not None:
|
|
294
|
-
if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
|
|
295
|
-
tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
|
|
296
|
-
if tensor_stat.min is not None:
|
|
297
|
-
if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
|
|
298
|
-
tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")
|
|
299
277
|
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
278
|
+
stat_values = [
|
|
279
|
+
tensor_stat.max,
|
|
280
|
+
tensor_stat.min,
|
|
281
|
+
tensor_stat.mean,
|
|
282
|
+
tensor_stat.norm
|
|
283
|
+
]
|
|
284
|
+
placeholder_index = self.data_writer.append_stat_to_buffer(stat_values)
|
|
285
|
+
|
|
286
|
+
tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index})
|
|
287
|
+
tensor_json.update({"requires_grad": tensor.requires_grad})
|
|
303
288
|
|
|
304
289
|
if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
|
|
305
290
|
tensor_md5 = self.get_md5_for_tensor(tensor)
|
|
306
291
|
tensor_json.update({Const.MD5: tensor_md5})
|
|
307
292
|
return tensor_json
|
|
308
293
|
|
|
309
|
-
|
|
310
|
-
class StatisticsDataProcessor(PytorchDataProcessor):
|
|
311
|
-
pass
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
class TensorDataProcessor(PytorchDataProcessor):
|
|
315
|
-
def dump_async_data(self):
|
|
316
|
-
for file_path, tensor in self._async_dump_cache.items():
|
|
317
|
-
save_pt(tensor.contiguous(), file_path)
|
|
318
|
-
self._async_dump_cache.clear()
|
|
319
|
-
|
|
320
|
-
def _analyze_tensor(self, tensor, suffix):
|
|
294
|
+
def _analyze_and_save_tensor(self, tensor, suffix):
|
|
321
295
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
322
|
-
single_arg =
|
|
296
|
+
single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix)
|
|
323
297
|
single_arg.update({"data_name": dump_data_name})
|
|
324
298
|
tensor, _ = self._cast_to_float_if_fp8(tensor)
|
|
325
299
|
if self.config.async_dump:
|
|
@@ -329,14 +303,36 @@ class TensorDataProcessor(PytorchDataProcessor):
|
|
|
329
303
|
save_pt(saved_tensor, file_path)
|
|
330
304
|
return single_arg
|
|
331
305
|
|
|
332
|
-
def
|
|
306
|
+
def _analyze_and_save_ndarray(self, ndarray, suffix):
|
|
333
307
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
334
308
|
save_pt(torch.tensor(ndarray), file_path)
|
|
335
|
-
ndarray_json =
|
|
309
|
+
ndarray_json = PytorchDataProcessor._analyze_ndarray(ndarray, suffix)
|
|
336
310
|
ndarray_json.update({"data_name": dump_data_name})
|
|
337
311
|
return ndarray_json
|
|
338
312
|
|
|
339
313
|
|
|
314
|
+
class StatisticsDataProcessor(PytorchDataProcessor):
|
|
315
|
+
def _analyze_tensor(self, tensor, suffix):
|
|
316
|
+
if any(item in self.current_api_or_module_name for item in self.config.tensor_list):
|
|
317
|
+
return self._analyze_and_save_tensor(tensor, suffix)
|
|
318
|
+
else:
|
|
319
|
+
return super()._analyze_tensor(tensor, suffix)
|
|
320
|
+
|
|
321
|
+
def _analyze_ndarray(self, ndarray, suffix):
|
|
322
|
+
if any(item in self.current_api_or_module_name for item in self.config.tensor_list):
|
|
323
|
+
return self._analyze_and_save_ndarray(ndarray, suffix)
|
|
324
|
+
else:
|
|
325
|
+
return super()._analyze_ndarray(ndarray, suffix)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
class TensorDataProcessor(PytorchDataProcessor):
|
|
329
|
+
def _analyze_tensor(self, tensor, suffix):
|
|
330
|
+
return self._analyze_and_save_tensor(tensor, suffix)
|
|
331
|
+
|
|
332
|
+
def _analyze_ndarray(self, ndarray, suffix):
|
|
333
|
+
return self._analyze_and_save_ndarray(ndarray, suffix)
|
|
334
|
+
|
|
335
|
+
|
|
340
336
|
class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
341
337
|
__slots__ = ["cached_tensors_and_file_paths"]
|
|
342
338
|
|
|
@@ -427,10 +423,22 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
427
423
|
raise RuntimeError(f"overflow check failed") from e
|
|
428
424
|
|
|
429
425
|
def _analyze_maybe_overflow_tensor(self, tensor_json):
|
|
430
|
-
|
|
426
|
+
tensor_stat_index = tensor_json.get(Const.TENSOR_STAT_INDEX)
|
|
427
|
+
if tensor_stat_index is None:
|
|
428
|
+
logger.warning("tensor_stat_index does not exist in tensor_json.")
|
|
431
429
|
return
|
|
432
|
-
|
|
433
|
-
|
|
430
|
+
max_tensor = self.data_writer.get_buffer_values_max(tensor_stat_index)
|
|
431
|
+
min_tensor = self.data_writer.get_buffer_values_min(tensor_stat_index)
|
|
432
|
+
|
|
433
|
+
if max_tensor is None or min_tensor is None:
|
|
434
|
+
return
|
|
435
|
+
|
|
436
|
+
if torch.isinf(max_tensor) or torch.isnan(max_tensor):
|
|
437
|
+
self.has_overflow = True
|
|
438
|
+
return
|
|
439
|
+
|
|
440
|
+
if torch.isinf(min_tensor) or torch.isnan(min_tensor):
|
|
441
|
+
self.has_overflow = True
|
|
434
442
|
|
|
435
443
|
def _analyze_tensor(self, tensor, suffix):
|
|
436
444
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|