mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -13,7 +13,6 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import hashlib
|
|
17
16
|
import zlib
|
|
18
17
|
from dataclasses import asdict
|
|
19
18
|
from typing import List
|
|
@@ -24,14 +23,15 @@ from torch import distributed as dist
|
|
|
24
23
|
from torch.distributed.distributed_c10d import _get_default_group
|
|
25
24
|
|
|
26
25
|
from msprobe.core.common.const import Const
|
|
26
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
27
27
|
from msprobe.core.common.file_utils import path_len_exceeds_limit
|
|
28
28
|
from msprobe.core.common.log import logger
|
|
29
29
|
from msprobe.core.common.utils import convert_tuple
|
|
30
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
30
31
|
from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
|
|
31
32
|
ModuleForwardInputsOutputs, TensorStatInfo
|
|
32
|
-
from msprobe.pytorch.common.utils import save_pt,
|
|
33
|
+
from msprobe.pytorch.common.utils import Const as PtConst, save_pt, is_hifloat8_tensor, is_float8_tensor
|
|
33
34
|
from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
|
|
34
|
-
from msprobe.core.common.utils import recursion_depth_decorator
|
|
35
35
|
|
|
36
36
|
is_gpu = False
|
|
37
37
|
try:
|
|
@@ -78,14 +78,16 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
78
78
|
def analyze_device_in_kwargs(element):
|
|
79
79
|
single_arg = {}
|
|
80
80
|
single_arg.update({'type': "torch.device"})
|
|
81
|
-
if
|
|
81
|
+
if isinstance(element, (int, str)):
|
|
82
|
+
single_arg.update({"value": element})
|
|
83
|
+
elif isinstance(element, torch.device):
|
|
82
84
|
if hasattr(element, "index"):
|
|
83
85
|
device_value = element.type + ":" + str(element.index)
|
|
84
86
|
else:
|
|
85
87
|
device_value = element.type
|
|
86
88
|
single_arg.update({"value": device_value})
|
|
87
89
|
else:
|
|
88
|
-
|
|
90
|
+
logger.debug(f"Device type {type(element)} is not supported.")
|
|
89
91
|
return single_arg
|
|
90
92
|
|
|
91
93
|
@staticmethod
|
|
@@ -99,19 +101,17 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
99
101
|
logger.warning("Async dump do not support complex data!")
|
|
100
102
|
return tensor_stat
|
|
101
103
|
elif data.dtype == torch.bool:
|
|
102
|
-
tensor_stat.
|
|
103
|
-
|
|
104
|
+
tensor_stat.max = torch.any(data)
|
|
105
|
+
tensor_stat.min = torch.all(data)
|
|
104
106
|
elif not data.shape:
|
|
105
|
-
tensor_stat.
|
|
107
|
+
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
|
|
106
108
|
else:
|
|
107
|
-
if
|
|
109
|
+
if data.dtype == torch.float64 or not data.is_floating_point():
|
|
108
110
|
data = data.float()
|
|
109
|
-
tensor_stat.
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
torch.norm(data)
|
|
114
|
-
]))
|
|
111
|
+
tensor_stat.max = torch.max(data)
|
|
112
|
+
tensor_stat.min = torch.min(data)
|
|
113
|
+
tensor_stat.mean = torch.mean(data)
|
|
114
|
+
tensor_stat.norm = torch.norm(data)
|
|
115
115
|
return tensor_stat
|
|
116
116
|
|
|
117
117
|
@staticmethod
|
|
@@ -124,17 +124,17 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
124
124
|
tensor_stat.min = np.min(data_abs).item()
|
|
125
125
|
tensor_stat.mean = np.mean(data_abs).item()
|
|
126
126
|
elif data.dtype == torch.bool:
|
|
127
|
-
tensor_stat.max = torch.any(data)
|
|
128
|
-
tensor_stat.min = torch.all(data)
|
|
127
|
+
tensor_stat.max = torch.any(data)
|
|
128
|
+
tensor_stat.min = torch.all(data)
|
|
129
129
|
elif not data.shape:
|
|
130
|
-
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
|
|
130
|
+
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
|
|
131
131
|
else:
|
|
132
|
-
if
|
|
132
|
+
if data.dtype == torch.float64 or not data.is_floating_point():
|
|
133
133
|
data = data.float()
|
|
134
|
-
tensor_stat.max = torch.max(data)
|
|
135
|
-
tensor_stat.min = torch.min(data)
|
|
136
|
-
tensor_stat.mean = torch.mean(data)
|
|
137
|
-
tensor_stat.norm = torch.norm(data)
|
|
134
|
+
tensor_stat.max = torch.max(data)
|
|
135
|
+
tensor_stat.min = torch.min(data)
|
|
136
|
+
tensor_stat.mean = torch.mean(data)
|
|
137
|
+
tensor_stat.norm = torch.norm(data)
|
|
138
138
|
return tensor_stat
|
|
139
139
|
|
|
140
140
|
@staticmethod
|
|
@@ -143,7 +143,7 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
143
143
|
if data.is_meta:
|
|
144
144
|
return tensor_stat
|
|
145
145
|
data_clone = data.detach()
|
|
146
|
-
if data_clone.numel()
|
|
146
|
+
if not data_clone.numel() or not data_clone.data_ptr():
|
|
147
147
|
return tensor_stat
|
|
148
148
|
else:
|
|
149
149
|
if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump:
|
|
@@ -171,12 +171,8 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
171
171
|
@staticmethod
|
|
172
172
|
def process_group_hash(arg):
|
|
173
173
|
group_ranks = dist.get_process_group_ranks(arg)
|
|
174
|
-
group_ranks_hash =
|
|
175
|
-
return group_ranks_hash
|
|
176
|
-
|
|
177
|
-
@staticmethod
|
|
178
|
-
def is_distributed_op(module):
|
|
179
|
-
return getattr(module, "op_is_distributed", False)
|
|
174
|
+
group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8'))
|
|
175
|
+
return f"{group_ranks_hash:08x}"
|
|
180
176
|
|
|
181
177
|
@staticmethod
|
|
182
178
|
def is_hookable_element(element):
|
|
@@ -214,43 +210,52 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
214
210
|
logger.warning(f"Failed to get value of torch.distributed.ReduceOp with error info: {e}.")
|
|
215
211
|
return {"type": "torch.distributed.ReduceOp", "value": op_type}
|
|
216
212
|
|
|
213
|
+
@staticmethod
|
|
214
|
+
def _cast_to_float_if_fp8(tensor):
|
|
215
|
+
dtype = str(tensor.dtype)
|
|
216
|
+
if is_float8_tensor(tensor):
|
|
217
|
+
dtype = PtConst.HIFLOAT8_TYPE if is_hifloat8_tensor(tensor) else dtype
|
|
218
|
+
logger.debug(
|
|
219
|
+
f"The {dtype} tensor analyzing/saving is unsupported in dump function."
|
|
220
|
+
f"Casting to float for processing."
|
|
221
|
+
)
|
|
222
|
+
tensor = tensor.float()
|
|
223
|
+
return tensor, dtype
|
|
224
|
+
|
|
217
225
|
@classmethod
|
|
218
226
|
def get_special_types(cls):
|
|
219
227
|
return super().get_special_types() + cls.pytorch_special_type
|
|
220
228
|
|
|
229
|
+
def dump_async_data(self):
|
|
230
|
+
for file_path, tensor in self._async_dump_cache.items():
|
|
231
|
+
save_pt(tensor.contiguous(), file_path)
|
|
232
|
+
self._async_dump_cache.clear()
|
|
233
|
+
|
|
221
234
|
def analyze_single_element(self, element, suffix_stack):
|
|
222
235
|
if suffix_stack and suffix_stack[-1] in self.torch_object_key:
|
|
223
236
|
return self.torch_object_key[suffix_stack[-1]](element)
|
|
224
|
-
if isinstance(element, torch.Size):
|
|
225
|
-
return self._analyze_torch_size(element)
|
|
226
|
-
if isinstance(element, torch.memory_format):
|
|
227
|
-
return self._analyze_memory_format(element)
|
|
228
|
-
if isinstance(element, dist.ProcessGroup):
|
|
229
|
-
return self._analyze_process_group(element)
|
|
230
|
-
if isinstance(element, dist.P2POp):
|
|
231
|
-
return self._analyze_p2pop(element)
|
|
232
|
-
if isinstance(element, dist.ReduceOp):
|
|
233
|
-
return self._analyze_reduce_op(element)
|
|
234
|
-
converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
|
|
235
|
-
if converted_numpy is not element:
|
|
236
|
-
return {"type": numpy_type, "value": converted_numpy}
|
|
237
|
-
if isinstance(element, torch.Tensor):
|
|
238
|
-
return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
|
|
239
|
-
if isinstance(element, np.ndarray):
|
|
240
|
-
return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
|
|
241
|
-
if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
|
|
242
|
-
return self._analyze_builtin(element)
|
|
243
|
-
return {}
|
|
244
237
|
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
238
|
+
suffix_str = Const.SEP.join(str(s) for s in suffix_stack)
|
|
239
|
+
type_analyzer = [
|
|
240
|
+
(PytorchDataProcessor.builtin_type, self._analyze_builtin),
|
|
241
|
+
(torch.Size, self._analyze_torch_size),
|
|
242
|
+
(torch.Tensor, lambda e: self._analyze_tensor(e, suffix_str)),
|
|
243
|
+
(torch.memory_format, self._analyze_memory_format),
|
|
244
|
+
(dist.ProcessGroup, self._analyze_process_group),
|
|
245
|
+
(dist.P2POp, lambda e: self._analyze_p2pop(e, suffix_str)),
|
|
246
|
+
(dist.ReduceOp, self._analyze_reduce_op),
|
|
247
|
+
(PytorchDataProcessor.np_type[:-1], self._analyze_numpy),
|
|
248
|
+
(np.ndarray, lambda e: self._analyze_ndarray(e, suffix_str)),
|
|
249
|
+
]
|
|
250
|
+
for type_key, analyze_fn in type_analyzer:
|
|
251
|
+
if isinstance(element, type_key):
|
|
252
|
+
return analyze_fn(element)
|
|
253
|
+
return {}
|
|
249
254
|
|
|
250
|
-
def _analyze_p2pop(self, arg):
|
|
255
|
+
def _analyze_p2pop(self, arg, suffix):
|
|
251
256
|
p2pop_info = {"class_type": "torch.distributed.P2POp"}
|
|
252
257
|
try:
|
|
253
|
-
tensor_info = self._analyze_tensor(arg.tensor,
|
|
258
|
+
tensor_info = self._analyze_tensor(arg.tensor, suffix)
|
|
254
259
|
p2pop_info.update({"tensor": tensor_info})
|
|
255
260
|
p2pop_info.update({"op": arg.op.__name__})
|
|
256
261
|
p2pop_info.update({"peer": arg.peer})
|
|
@@ -263,63 +268,71 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
263
268
|
return p2pop_info
|
|
264
269
|
|
|
265
270
|
def _analyze_tensor(self, tensor, suffix):
|
|
271
|
+
tensor, dtype = self._cast_to_float_if_fp8(tensor)
|
|
266
272
|
tensor_stat = self.get_stat_info(tensor, self.config.async_dump)
|
|
267
273
|
tensor_json = {}
|
|
268
274
|
tensor_json.update({'type': 'torch.Tensor'})
|
|
269
|
-
tensor_json.update({'dtype':
|
|
275
|
+
tensor_json.update({'dtype': dtype})
|
|
270
276
|
tensor_json.update({"shape": tensor.shape})
|
|
271
|
-
if tensor_stat.stack_tensor_stat is None:
|
|
272
|
-
tensor_json.update({"Max": tensor_stat.max})
|
|
273
|
-
tensor_json.update({"Min": tensor_stat.min})
|
|
274
|
-
tensor_json.update({"Mean": tensor_stat.mean})
|
|
275
|
-
tensor_json.update({"Norm": tensor_stat.norm})
|
|
276
|
-
tensor_json.update({"requires_grad": tensor.requires_grad})
|
|
277
|
-
if tensor_stat.max is not None:
|
|
278
|
-
if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
|
|
279
|
-
tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
|
|
280
|
-
if tensor_stat.min is not None:
|
|
281
|
-
if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
|
|
282
|
-
tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")
|
|
283
277
|
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
278
|
+
stat_values = [
|
|
279
|
+
tensor_stat.max,
|
|
280
|
+
tensor_stat.min,
|
|
281
|
+
tensor_stat.mean,
|
|
282
|
+
tensor_stat.norm
|
|
283
|
+
]
|
|
284
|
+
placeholder_index = self.data_writer.append_stat_to_buffer(stat_values)
|
|
285
|
+
|
|
286
|
+
tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index})
|
|
287
|
+
tensor_json.update({"requires_grad": tensor.requires_grad})
|
|
287
288
|
|
|
288
289
|
if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
|
|
289
290
|
tensor_md5 = self.get_md5_for_tensor(tensor)
|
|
290
291
|
tensor_json.update({Const.MD5: tensor_md5})
|
|
291
292
|
return tensor_json
|
|
292
293
|
|
|
293
|
-
|
|
294
|
-
class StatisticsDataProcessor(PytorchDataProcessor):
|
|
295
|
-
pass
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
class TensorDataProcessor(PytorchDataProcessor):
|
|
299
|
-
def dump_async_data(self):
|
|
300
|
-
for file_path, tensor in self._async_dump_cache.items():
|
|
301
|
-
save_pt(tensor.contiguous(), file_path)
|
|
302
|
-
self._async_dump_cache.clear()
|
|
303
|
-
|
|
304
|
-
def _analyze_tensor(self, tensor, suffix):
|
|
294
|
+
def _analyze_and_save_tensor(self, tensor, suffix):
|
|
305
295
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
306
|
-
single_arg =
|
|
296
|
+
single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix)
|
|
307
297
|
single_arg.update({"data_name": dump_data_name})
|
|
298
|
+
tensor, _ = self._cast_to_float_if_fp8(tensor)
|
|
308
299
|
if self.config.async_dump:
|
|
309
300
|
self._async_dump_cache[file_path] = tensor.clone().detach()
|
|
310
301
|
else:
|
|
311
302
|
saved_tensor = tensor.clone().contiguous().detach()
|
|
312
303
|
save_pt(saved_tensor, file_path)
|
|
313
304
|
return single_arg
|
|
314
|
-
|
|
315
|
-
def
|
|
305
|
+
|
|
306
|
+
def _analyze_and_save_ndarray(self, ndarray, suffix):
|
|
316
307
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
317
308
|
save_pt(torch.tensor(ndarray), file_path)
|
|
318
|
-
ndarray_json =
|
|
309
|
+
ndarray_json = PytorchDataProcessor._analyze_ndarray(ndarray, suffix)
|
|
319
310
|
ndarray_json.update({"data_name": dump_data_name})
|
|
320
311
|
return ndarray_json
|
|
321
312
|
|
|
322
313
|
|
|
314
|
+
class StatisticsDataProcessor(PytorchDataProcessor):
|
|
315
|
+
def _analyze_tensor(self, tensor, suffix):
|
|
316
|
+
if any(item in self.current_api_or_module_name for item in self.config.tensor_list):
|
|
317
|
+
return self._analyze_and_save_tensor(tensor, suffix)
|
|
318
|
+
else:
|
|
319
|
+
return super()._analyze_tensor(tensor, suffix)
|
|
320
|
+
|
|
321
|
+
def _analyze_ndarray(self, ndarray, suffix):
|
|
322
|
+
if any(item in self.current_api_or_module_name for item in self.config.tensor_list):
|
|
323
|
+
return self._analyze_and_save_ndarray(ndarray, suffix)
|
|
324
|
+
else:
|
|
325
|
+
return super()._analyze_ndarray(ndarray, suffix)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
class TensorDataProcessor(PytorchDataProcessor):
|
|
329
|
+
def _analyze_tensor(self, tensor, suffix):
|
|
330
|
+
return self._analyze_and_save_tensor(tensor, suffix)
|
|
331
|
+
|
|
332
|
+
def _analyze_ndarray(self, ndarray, suffix):
|
|
333
|
+
return self._analyze_and_save_ndarray(ndarray, suffix)
|
|
334
|
+
|
|
335
|
+
|
|
323
336
|
class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
324
337
|
__slots__ = ["cached_tensors_and_file_paths"]
|
|
325
338
|
|
|
@@ -383,7 +396,8 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
383
396
|
self._analyze_maybe_overflow_flag()
|
|
384
397
|
if self.has_overflow:
|
|
385
398
|
for file_path, tensor in self.cached_tensors_and_file_paths.items():
|
|
386
|
-
|
|
399
|
+
tensor, _ = self._cast_to_float_if_fp8(tensor)
|
|
400
|
+
save_pt(tensor.clone().contiguous().detach(), file_path)
|
|
387
401
|
self.real_overflow_nums += 1
|
|
388
402
|
if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums:
|
|
389
403
|
logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, "
|
|
@@ -409,10 +423,22 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
409
423
|
raise RuntimeError(f"overflow check failed") from e
|
|
410
424
|
|
|
411
425
|
def _analyze_maybe_overflow_tensor(self, tensor_json):
|
|
412
|
-
|
|
426
|
+
tensor_stat_index = tensor_json.get(Const.TENSOR_STAT_INDEX)
|
|
427
|
+
if tensor_stat_index is None:
|
|
428
|
+
logger.warning("tensor_stat_index does not exist in tensor_json.")
|
|
429
|
+
return
|
|
430
|
+
max_tensor = self.data_writer.get_buffer_values_max(tensor_stat_index)
|
|
431
|
+
min_tensor = self.data_writer.get_buffer_values_min(tensor_stat_index)
|
|
432
|
+
|
|
433
|
+
if max_tensor is None or min_tensor is None:
|
|
434
|
+
return
|
|
435
|
+
|
|
436
|
+
if torch.isinf(max_tensor) or torch.isnan(max_tensor):
|
|
437
|
+
self.has_overflow = True
|
|
413
438
|
return
|
|
414
|
-
|
|
415
|
-
|
|
439
|
+
|
|
440
|
+
if torch.isinf(min_tensor) or torch.isnan(min_tensor):
|
|
441
|
+
self.has_overflow = True
|
|
416
442
|
|
|
417
443
|
def _analyze_tensor(self, tensor, suffix):
|
|
418
444
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
@@ -508,11 +534,13 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
|
|
|
508
534
|
return
|
|
509
535
|
|
|
510
536
|
if self.config.is_backward_kernel_dump:
|
|
511
|
-
self.forward_args = self.clone_and_detach_tensor(module_input_output.args)
|
|
512
|
-
self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs)
|
|
513
537
|
try:
|
|
538
|
+
self.forward_args = self.clone_and_detach_tensor(module_input_output.args)
|
|
539
|
+
self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs)
|
|
514
540
|
output = module.forward(*self.forward_args, **self.forward_kwargs)
|
|
515
|
-
except Exception:
|
|
541
|
+
except Exception as e:
|
|
542
|
+
if isinstance(e, MsprobeException):
|
|
543
|
+
logger.warning(str(e))
|
|
516
544
|
self._print_unsupported_log(name)
|
|
517
545
|
self.enable_kernel_dump = False
|
|
518
546
|
return
|
|
@@ -554,9 +582,17 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
|
|
|
554
582
|
self.stop_kernel_dump()
|
|
555
583
|
logger.info(f"The kernel data of {name} is dumped successfully.")
|
|
556
584
|
|
|
557
|
-
@recursion_depth_decorator(
|
|
585
|
+
@recursion_depth_decorator(
|
|
586
|
+
"KernelDump: KernelDumpDataProcessor.clone_and_detach_tensor",
|
|
587
|
+
max_depth=Const.DUMP_MAX_DEPTH
|
|
588
|
+
)
|
|
558
589
|
def clone_and_detach_tensor(self, input_params):
|
|
559
590
|
if isinstance(input_params, torch.Tensor):
|
|
591
|
+
if is_float8_tensor(input_params):
|
|
592
|
+
raise MsprobeException(
|
|
593
|
+
MsprobeException.UNSUPPORTED_TYPE_ERROR,
|
|
594
|
+
f"L2 backward dump does not support float8 type."
|
|
595
|
+
)
|
|
560
596
|
if input_params.requires_grad:
|
|
561
597
|
return input_params.clone().detach().requires_grad_()
|
|
562
598
|
return input_params.clone()
|
|
@@ -571,6 +607,8 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
|
|
|
571
607
|
|
|
572
608
|
def analyze_single_element(self, element, suffix_stack):
|
|
573
609
|
if isinstance(element, torch.Tensor):
|
|
610
|
+
if is_float8_tensor(element):
|
|
611
|
+
return {}
|
|
574
612
|
if not self.is_found_output_tensor:
|
|
575
613
|
if element.requires_grad:
|
|
576
614
|
self.forward_output_tensor = element
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -16,12 +16,14 @@
|
|
|
16
16
|
import csv
|
|
17
17
|
import os
|
|
18
18
|
import copy
|
|
19
|
-
import
|
|
19
|
+
import threading
|
|
20
20
|
|
|
21
21
|
from msprobe.core.common.const import Const, FileCheckConst
|
|
22
22
|
from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json
|
|
23
23
|
from msprobe.core.common.log import logger
|
|
24
|
-
from msprobe.core.common.
|
|
24
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
25
|
+
|
|
26
|
+
# Module-level lock taken by the DataWriter update_* methods and write_json
# while they touch the in-memory caches. threading.Lock is not reentrant, so
# code that already holds it must not call another method that acquires it.
lock = threading.Lock()
|
|
25
27
|
|
|
26
28
|
|
|
27
29
|
class DataWriter:
|
|
@@ -34,10 +36,12 @@ class DataWriter:
|
|
|
34
36
|
self.dump_tensor_data_dir = None
|
|
35
37
|
self.debug_file_path = None
|
|
36
38
|
self.flush_size = 1000
|
|
39
|
+
self.larger_flush_size = 20000
|
|
37
40
|
self.cache_data = {}
|
|
38
41
|
self.cache_stack = {}
|
|
39
42
|
self.cache_construct = {}
|
|
40
43
|
self.cache_debug = {}
|
|
44
|
+
self.stat_stack_list = []
|
|
41
45
|
|
|
42
46
|
@staticmethod
|
|
43
47
|
def write_data_to_csv(result: list, result_header: tuple, file_path: str):
|
|
@@ -54,13 +58,54 @@ class DataWriter:
|
|
|
54
58
|
if is_new_file:
|
|
55
59
|
change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
56
60
|
|
|
61
|
+
@recursion_depth_decorator("JsonWriter: DataWriter._replace_stat_placeholders")
def _replace_stat_placeholders(self, data, stat_result):
    """Recursively swap statistic-index placeholders for real statistics.

    Walks nested dicts, lists and tuples. A dict entry whose key is
    ``Const.TENSOR_STAT_INDEX`` and whose value is a non-negative int is
    removed, and the dict is rebuilt in place so that type, dtype, shape,
    max, min, mean and norm come first, followed by the remaining keys in
    their original order. An index beyond ``stat_result`` yields four
    ``None`` statistics. A negative index stops processing of that dict.
    """
    if isinstance(data, (list, tuple)):
        for element in data:
            self._replace_stat_placeholders(element, stat_result)
        return
    if not isinstance(data, dict):
        return

    # Iterate over a snapshot of the keys: the dict is rebuilt while looping.
    for key in list(data.keys()):
        value = data[key]
        is_placeholder = key == Const.TENSOR_STAT_INDEX and isinstance(value, int)
        if not is_placeholder:
            self._replace_stat_placeholders(value, stat_result)
            continue
        if value < 0:
            # Negative index: leave this dict untouched from here on.
            return

        stat_values = stat_result[value] if value < len(stat_result) else [None] * 4
        replacement = {
            Const.TYPE: data["type"],
            Const.DTYPE: data["dtype"],
            Const.SHAPE: data["shape"],
            Const.MAX: stat_values[0],
            Const.MIN: stat_values[1],
            Const.MEAN: stat_values[2],
            Const.NORM: stat_values[3],
        }
        del data[key]

        # Statistics fields go first so they are written to JSON in a fixed
        # order; every other surviving field keeps its relative position.
        rebuilt = dict(replacement)
        rebuilt.update((k, v) for k, v in data.items() if k not in replacement)
        data.clear()
        data.update(rebuilt)
|
|
100
|
+
|
|
57
101
|
def reset_cache(self):
    """Replace the data, stack, construct and debug caches with fresh empty dicts."""
    for cache_name in ("cache_data", "cache_stack", "cache_construct", "cache_debug"):
        # A new dict per attribute, so the caches never share storage.
        setattr(self, cache_name, {})
|
|
61
106
|
|
|
62
107
|
def initialize_json_file(self, **kwargs):
|
|
63
|
-
if
|
|
108
|
+
if kwargs["level"] == Const.LEVEL_DEBUG and not self.cache_debug:
|
|
64
109
|
# debug level case only create debug.json
|
|
65
110
|
debug_dict = copy.deepcopy(kwargs)
|
|
66
111
|
debug_dict.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
|
|
@@ -86,39 +131,59 @@ class DataWriter:
|
|
|
86
131
|
|
|
87
132
|
def flush_data_periodically(self):
    """Write the JSON files whenever the cached entry count hits a flush boundary.

    The boundary is every ``self.flush_size`` entries while the cache holds
    fewer than ``self.larger_flush_size`` entries, and every
    ``self.larger_flush_size`` entries from then on. Nothing happens when
    the cached data is missing, empty or not a dict.
    """
    cached = self.cache_data.get(Const.DATA)
    if not (cached and isinstance(cached, dict)):
        return

    count = len(cached)
    if count < self.larger_flush_size:
        interval = self.flush_size
    else:
        interval = self.larger_flush_size

    if count % interval:
        return
    self.write_json()
|
|
144
|
+
|
|
145
|
+
def update_data(self, new_data):
    """Merge a single-key dict into the cached dump data while holding the lock.

    ``new_data`` must be a dict with exactly one outer key. If that key is
    already cached, its value is updated in place with the new value;
    otherwise the key is added. Invalid input, or cached dump data that is
    not a dict, only produces a warning.
    """
    with lock:
        well_formed = isinstance(new_data, dict) and len(new_data.keys()) == 1
        if not well_formed:
            logger.warning(f"The data info({new_data}) should be a dict with only one outer key.")
            return

        dump_data = self.cache_data.get(Const.DATA)
        if not isinstance(dump_data, dict):
            logger.warning(f"The dump data({dump_data}) should be a dict.")
            return

        key, payload = next(iter(new_data.items()))
        if key not in dump_data:
            dump_data[key] = payload
            return
        dump_data[key].update(payload)
|
|
160
|
+
|
|
161
|
+
def update_stack(self, name, stack_data):
    """Record ``name`` under ``stack_data`` in the stack cache while holding the lock.

    ``stack_data`` is the dictionary key, so identical stacks share one
    entry whose value is the list of names recorded for it.
    """
    with lock:
        recorded_names = self.cache_stack.get(stack_data)
        if recorded_names is not None:
            recorded_names.append(name)
            return
        self.cache_stack[stack_data] = [name]
|
|
109
168
|
|
|
110
169
|
def update_construct(self, new_data):
    """Merge ``new_data`` into the construct cache while holding the lock."""
    with lock:
        construct_cache = self.cache_construct
        construct_cache.update(new_data)
|
|
112
172
|
|
|
113
173
|
def update_debug(self, new_data):
    """Merge ``new_data`` into the ``'data'`` entry of the debug cache while holding the lock.

    Raises KeyError when the debug cache has no ``'data'`` entry yet.
    """
    with lock:
        debug_entries = self.cache_debug['data']
        debug_entries.update(new_data)
|
|
115
176
|
|
|
116
177
|
def write_data_json(self, file_path):
    """Log the directory two levels above ``file_path``, then save the data cache there as JSON."""
    logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
    payload = self.cache_data
    save_json(file_path, payload, indent=1)
|
|
119
180
|
|
|
120
181
|
def write_stack_info_json(self, file_path):
    """Save the stack cache to ``file_path`` keyed by a running integer.

    Each cached ``stack -> names`` pair becomes ``position: [names, stack]``,
    numbered from 0 in the cache's iteration order.
    """
    indexed_stack = {
        position: [api_names, stack_info]
        for position, (stack_info, api_names) in enumerate(self.cache_stack.items())
    }
    save_json(file_path, indexed_stack, indent=1)
|
|
122
187
|
|
|
123
188
|
def write_construct_info_json(self, file_path):
    """Save the construct cache to ``file_path`` through ``save_json`` with an indent of 1."""
    construct_snapshot = self.cache_construct
    save_json(file_path, construct_snapshot, indent=1)
|
|
@@ -126,38 +191,62 @@ class DataWriter:
|
|
|
126
191
|
def write_debug_info_json(self, file_path):
    """Save the debug cache to ``file_path`` through ``save_json`` with an indent of 1."""
    debug_snapshot = self.cache_debug
    save_json(file_path, debug_snapshot, indent=1)
|
|
128
193
|
|
|
194
|
+
def append_stat_to_buffer(self, stat_vector):
    """Store ``stat_vector`` in the plain Python list ``self.stat_stack_list``.

    Returns the position of the stored vector, which callers can keep as a
    placeholder index until the buffer is flushed.
    """
    buffer = self.stat_stack_list
    buffer.append(stat_vector)
    return len(buffer) - 1
|
|
201
|
+
|
|
202
|
+
def get_buffer_values_max(self, index):
    """Return the first element (the maximum) of the buffered stat vector at ``index``.

    Logs a warning and returns None when ``index`` is out of range or the
    vector at that position is empty.
    """
    buffer = self.stat_stack_list
    in_range = 0 <= index < len(buffer)
    if not in_range or len(buffer[index]) < 1:
        logger.warning(f"stat_stack_list[{index}] The internal data is incomplete,"
                       f" and the maximum value cannot be obtained.")
        return None
    return buffer[index][0]
|
|
209
|
+
|
|
210
|
+
def get_buffer_values_min(self, index):
    """Return the second element (the minimum) of the buffered stat vector at ``index``.

    Logs a warning and returns None when ``index`` is out of range or the
    vector at that position has fewer than two elements.
    """
    # The minimum lives at position 1, so the vector needs at least two
    # elements; requiring only one would raise IndexError on the read below.
    if 0 <= index < len(self.stat_stack_list) and len(self.stat_stack_list[index]) >= 2:
        return self.stat_stack_list[index][1]
    else:
        logger.warning(f"stat_stack_list[{index}] Internal data is incomplete"
                       f" and minimum values cannot be obtained.")
        return None
|
|
217
|
+
|
|
218
|
+
def flush_stat_stack(self):
    """Convert every buffered stat vector to plain values and empty the buffer.

    Each element exposing ``item()`` is replaced by the result of that call;
    other elements are kept as they are. Returns one list per buffered
    vector (presumably [max, min, mean, norm] - confirm against the code
    that fills the buffer), or an empty list when nothing is buffered.
    """
    pending = self.stat_stack_list
    if not pending:
        return []

    converted = []
    for stat_values in pending:
        row = []
        for entry in stat_values:
            row.append(entry.item() if hasattr(entry, "item") else entry)
        converted.append(row)

    self.stat_stack_list = []
    return converted
|
|
234
|
+
|
|
129
235
|
def write_json(self):
    """Resolve buffered statistics and write every non-empty cache to its file.

    Runs entirely under the module lock. The stat buffer is flushed first;
    if it produced results, the index placeholders in the data cache (and in
    the debug cache, when it is non-empty) are replaced by the statistics.
    Then the data, stack, construct and debug caches are written, in that
    order, each only when it is non-empty.
    """
    with lock:
        # Fetch all statistics once, before anything is serialised.
        stat_result = self.flush_stat_stack()
        have_stats = bool(stat_result)

        # Walk the caches and replace placeholders with the final statistics.
        if have_stats:
            self._replace_stat_placeholders(self.cache_data, stat_result)
        if have_stats and self.cache_debug:
            self._replace_stat_placeholders(self.cache_debug, stat_result)

        outputs = (
            ("cache_data", self.write_data_json, "dump_file_path"),
            ("cache_stack", self.write_stack_info_json, "stack_file_path"),
            ("cache_construct", self.write_construct_info_json, "construct_file_path"),
            ("cache_debug", self.write_debug_info_json, "debug_file_path"),
        )
        for cache_attr, writer, path_attr in outputs:
            # The path attribute is read only when its cache has content.
            if getattr(self, cache_attr):
                writer(getattr(self, path_attr))
|
|
252
|
+
|