mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,17 +13,17 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import copy
|
|
16
17
|
import inspect
|
|
17
18
|
import os
|
|
18
19
|
from dataclasses import dataclass, is_dataclass
|
|
19
|
-
from typing import Tuple, Dict, Optional, Any
|
|
20
20
|
from functools import partial
|
|
21
|
-
import
|
|
22
|
-
from typing import Union
|
|
21
|
+
from typing import Tuple, Dict, Optional, Any, Union
|
|
23
22
|
|
|
24
23
|
import numpy as np
|
|
25
24
|
|
|
26
25
|
from msprobe.core.common.const import Const
|
|
26
|
+
from msprobe.core.common.file_utils import save_npy
|
|
27
27
|
from msprobe.core.common.log import logger
|
|
28
28
|
from msprobe.core.common.utils import convert_tuple, CompareException
|
|
29
29
|
|
|
@@ -79,21 +79,17 @@ class ModuleBackwardOutputs:
|
|
|
79
79
|
|
|
80
80
|
|
|
81
81
|
class TensorStatInfo:
|
|
82
|
-
def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None
|
|
82
|
+
def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
|
|
83
83
|
self.max = max_val
|
|
84
84
|
self.min = min_val
|
|
85
85
|
self.mean = mean_val
|
|
86
86
|
self.norm = norm_val
|
|
87
|
-
self.stack_tensor_stat = stack_tensor_stat
|
|
88
87
|
|
|
89
88
|
|
|
90
89
|
class BaseDataProcessor:
|
|
91
90
|
_recursive_key_stack = []
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
bool, int, float, str, slice,
|
|
95
|
-
type(Ellipsis)
|
|
96
|
-
)
|
|
91
|
+
builtin_type = (bool, int, float, str, slice, type(Ellipsis))
|
|
92
|
+
np_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray)
|
|
97
93
|
|
|
98
94
|
def __init__(self, config, data_writer):
|
|
99
95
|
self.data_writer = data_writer
|
|
@@ -120,7 +116,10 @@ class BaseDataProcessor:
|
|
|
120
116
|
@staticmethod
|
|
121
117
|
def analyze_api_call_stack(name):
|
|
122
118
|
try:
|
|
123
|
-
|
|
119
|
+
if name.startswith("Primitive"):
|
|
120
|
+
api_stack = inspect.stack()[4:]
|
|
121
|
+
else:
|
|
122
|
+
api_stack = inspect.stack()[5:]
|
|
124
123
|
except Exception as e:
|
|
125
124
|
logger.warning(f"The call stack of <{name}> failed to retrieve, {e}.")
|
|
126
125
|
api_stack = None
|
|
@@ -129,12 +128,14 @@ class BaseDataProcessor:
|
|
|
129
128
|
for (_, path, line, func, code, _) in api_stack:
|
|
130
129
|
if not code:
|
|
131
130
|
continue
|
|
131
|
+
if any(filter_path in path for filter_path in Const.STACK_FILTER_KEYWORDS) and \
|
|
132
|
+
Const.CALL_STACK_FLAG not in path:
|
|
133
|
+
continue
|
|
132
134
|
stack_line = f"File {path}, line {str(line)}, in {func}, \n {code[0].strip()}"
|
|
133
135
|
stack_str.append(stack_line)
|
|
134
136
|
else:
|
|
135
137
|
stack_str.append(Const.WITHOUT_CALL_STACK)
|
|
136
|
-
|
|
137
|
-
return stack_info_struct
|
|
138
|
+
return tuple(stack_str)
|
|
138
139
|
|
|
139
140
|
@staticmethod
|
|
140
141
|
def transfer_type(data):
|
|
@@ -178,20 +179,8 @@ class BaseDataProcessor:
|
|
|
178
179
|
"invalid data_structure type or invalid index")
|
|
179
180
|
|
|
180
181
|
@staticmethod
|
|
181
|
-
def
|
|
182
|
-
|
|
183
|
-
np.integer: int,
|
|
184
|
-
np.floating: float,
|
|
185
|
-
np.bool_: bool,
|
|
186
|
-
np.complexfloating: complex,
|
|
187
|
-
np.str_: str,
|
|
188
|
-
np.byte: bytes,
|
|
189
|
-
np.unicode_: str
|
|
190
|
-
}
|
|
191
|
-
for numpy_type, builtin_type in type_mapping.items():
|
|
192
|
-
if isinstance(arg, numpy_type):
|
|
193
|
-
return builtin_type(arg), type(arg).__name__
|
|
194
|
-
return arg, ''
|
|
182
|
+
def is_distributed_op(module):
|
|
183
|
+
return getattr(module, "op_is_distributed", False)
|
|
195
184
|
|
|
196
185
|
@staticmethod
|
|
197
186
|
def _analyze_builtin(arg):
|
|
@@ -217,21 +206,40 @@ class BaseDataProcessor:
|
|
|
217
206
|
return single_arg
|
|
218
207
|
|
|
219
208
|
@staticmethod
|
|
220
|
-
def _analyze_numpy(
|
|
209
|
+
def _analyze_numpy(arg):
|
|
210
|
+
return {"type": type(arg).__name__, "value": arg.item()}
|
|
211
|
+
|
|
212
|
+
@staticmethod
|
|
213
|
+
def _analyze_ndarray(ndarray, _):
|
|
221
214
|
ndarray_json = {}
|
|
222
215
|
ndarray_json.update({'type': 'numpy.ndarray'})
|
|
223
216
|
ndarray_json.update({'dtype': str(ndarray.dtype)})
|
|
224
217
|
ndarray_json.update({'shape': ndarray.shape})
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
218
|
+
|
|
219
|
+
# 先初始化默认值
|
|
220
|
+
stats = {
|
|
221
|
+
"Max": None,
|
|
222
|
+
"Min": None,
|
|
223
|
+
"Mean": None,
|
|
224
|
+
"Norm": None
|
|
225
|
+
}
|
|
226
|
+
|
|
227
|
+
try:
|
|
228
|
+
# 只有非空时才尝试计算
|
|
229
|
+
if ndarray.size > 0:
|
|
230
|
+
stats = {
|
|
231
|
+
"Max": np.max(ndarray).item(),
|
|
232
|
+
"Min": np.min(ndarray).item(),
|
|
233
|
+
"Mean": np.mean(ndarray).item(),
|
|
234
|
+
"Norm": np.linalg.norm(ndarray).item()
|
|
235
|
+
}
|
|
236
|
+
except Exception as e:
|
|
237
|
+
# 决定打印内容或切片
|
|
238
|
+
logger.warning(f"Error analyzing ndarray stats: {e}")
|
|
239
|
+
|
|
240
|
+
# 最后一次性更新
|
|
241
|
+
ndarray_json.update(stats)
|
|
242
|
+
|
|
235
243
|
return ndarray_json
|
|
236
244
|
|
|
237
245
|
@staticmethod
|
|
@@ -248,12 +256,12 @@ class BaseDataProcessor:
|
|
|
248
256
|
|
|
249
257
|
@classmethod
|
|
250
258
|
def get_special_types(cls):
|
|
251
|
-
return cls.
|
|
259
|
+
return cls.builtin_type + cls.np_type
|
|
252
260
|
|
|
253
261
|
@classmethod
|
|
254
262
|
def recursive_apply_transform(cls, args, transform, depth=0) -> Union[dict, list, None]:
|
|
255
|
-
if depth > Const.
|
|
256
|
-
logger.error(f"The maximum depth of recursive transform, {Const.
|
|
263
|
+
if depth > Const.DUMP_MAX_DEPTH:
|
|
264
|
+
logger.error(f"The maximum depth of recursive transform, {Const.DUMP_MAX_DEPTH} is reached.")
|
|
257
265
|
raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
|
|
258
266
|
if isinstance(args, cls.get_special_types()):
|
|
259
267
|
arg_transform = transform(args, cls._recursive_key_stack)
|
|
@@ -303,6 +311,7 @@ class BaseDataProcessor:
|
|
|
303
311
|
|
|
304
312
|
def real_hook_fn(grad):
|
|
305
313
|
return wrap_hook_fn(grad)
|
|
314
|
+
|
|
306
315
|
element.register_hook(real_hook_fn)
|
|
307
316
|
|
|
308
317
|
def if_return_forward_new_output(self):
|
|
@@ -350,6 +359,8 @@ class BaseDataProcessor:
|
|
|
350
359
|
return api_info_struct
|
|
351
360
|
|
|
352
361
|
def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
362
|
+
if self.is_distributed_op(module):
|
|
363
|
+
module_input_output.update_output_with_args_and_kwargs()
|
|
353
364
|
api_info_struct = {}
|
|
354
365
|
# check whether data_mode contains forward or input
|
|
355
366
|
if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
|
|
@@ -427,6 +438,7 @@ class BaseDataProcessor:
|
|
|
427
438
|
api_info_struct = {}
|
|
428
439
|
self.save_name = name + Const.SEP + param_name
|
|
429
440
|
data_info = self.analyze_element(grad)
|
|
441
|
+
self.save_name = None
|
|
430
442
|
grad_info_dict = {param_name: [data_info]}
|
|
431
443
|
api_info_struct[name] = grad_info_dict
|
|
432
444
|
return api_info_struct
|
|
@@ -435,10 +447,10 @@ class BaseDataProcessor:
|
|
|
435
447
|
file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
|
|
436
448
|
if self.save_name is not None:
|
|
437
449
|
dump_data_name = (self.save_name + file_format)
|
|
438
|
-
self.save_name = None
|
|
439
450
|
else:
|
|
440
|
-
|
|
441
|
-
|
|
451
|
+
suffix_with_seq = (Const.SEP + suffix) if suffix else ""
|
|
452
|
+
dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + suffix_with_seq +
|
|
453
|
+
file_format)
|
|
442
454
|
file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
|
|
443
455
|
return dump_data_name, file_path
|
|
444
456
|
|
|
@@ -447,23 +459,32 @@ class BaseDataProcessor:
|
|
|
447
459
|
|
|
448
460
|
def analyze_debug_forward(self, variable, name_with_count):
|
|
449
461
|
self.current_api_or_module_name = name_with_count
|
|
450
|
-
self.api_data_category = Const.
|
|
451
|
-
# these two attributes are used to construct tensor file name {name_with_count}.
|
|
462
|
+
self.api_data_category = Const.DEBUG
|
|
463
|
+
# these two attributes are used to construct tensor file name {name_with_count}.debug.{indexes}.npy/pt
|
|
452
464
|
data_info = self.analyze_element(variable)
|
|
453
465
|
return data_info
|
|
454
466
|
|
|
455
|
-
def analyze_debug_backward(self, variable,
|
|
467
|
+
def analyze_debug_backward(self, variable, grad_name_with_count_category, nested_data_structure):
|
|
456
468
|
def hook_fn(grad, indexes):
|
|
457
469
|
suffix = Const.SEP.join([str(index) for index in indexes])
|
|
458
|
-
|
|
470
|
+
suffix_with_sep = (Const.SEP + suffix) if suffix else ""
|
|
471
|
+
self.save_name = grad_name_with_count_category + suffix_with_sep
|
|
459
472
|
grad_data_info = self.analyze_element(grad)
|
|
460
473
|
self.save_name = None
|
|
461
|
-
full_index = [
|
|
474
|
+
full_index = [grad_name_with_count_category] + indexes
|
|
462
475
|
try:
|
|
463
476
|
self.set_value_into_nested_structure(nested_data_structure, full_index, grad_data_info)
|
|
464
477
|
except (ValueError, IndexError) as e:
|
|
465
|
-
logger.warning(f"error
|
|
466
|
-
f"skip current recording, detailed
|
|
478
|
+
logger.warning(f"error occurred while recording statistics of {grad_name_with_count_category} variable,"
|
|
479
|
+
f"skip current recording, detailed information: {e}")
|
|
467
480
|
return grad
|
|
481
|
+
|
|
468
482
|
wrap_register_hook_single_element = partial(self.register_hook_single_element, hook_fn=hook_fn)
|
|
469
|
-
self.recursive_apply_transform(variable, wrap_register_hook_single_element)
|
|
483
|
+
self.recursive_apply_transform(variable, wrap_register_hook_single_element)
|
|
484
|
+
|
|
485
|
+
def _analyze_and_save_ndarray(self, ndarray, suffix):
|
|
486
|
+
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
487
|
+
save_npy(ndarray, file_path)
|
|
488
|
+
ndarray_json = BaseDataProcessor._analyze_ndarray(ndarray, suffix)
|
|
489
|
+
ndarray_json.update({"data_name": dump_data_name})
|
|
490
|
+
return ndarray_json
|
|
@@ -17,16 +17,17 @@ import zlib
|
|
|
17
17
|
|
|
18
18
|
import mindspore as ms
|
|
19
19
|
from mindspore import mint, ops, hal
|
|
20
|
+
from mindspore.mint import distributed
|
|
20
21
|
from mindspore._c_expression.typing import Number
|
|
21
22
|
import numpy as np
|
|
22
23
|
|
|
23
24
|
from msprobe.core.common.const import Const
|
|
24
25
|
from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo,
|
|
25
26
|
ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs)
|
|
26
|
-
from msprobe.core.common.file_utils import path_len_exceeds_limit
|
|
27
|
+
from msprobe.core.common.file_utils import path_len_exceeds_limit
|
|
27
28
|
from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy
|
|
28
29
|
from msprobe.mindspore.common.log import logger
|
|
29
|
-
from msprobe.mindspore.dump.hook_cell.
|
|
30
|
+
from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
|
|
30
31
|
|
|
31
32
|
has_adump = True
|
|
32
33
|
try:
|
|
@@ -36,7 +37,7 @@ except ImportError:
|
|
|
36
37
|
|
|
37
38
|
|
|
38
39
|
class MindsporeDataProcessor(BaseDataProcessor):
|
|
39
|
-
mindspore_special_type = tuple([ms.Tensor, Number])
|
|
40
|
+
mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp])
|
|
40
41
|
|
|
41
42
|
def __init__(self, config, data_writer):
|
|
42
43
|
super().__init__(config, data_writer)
|
|
@@ -44,6 +45,7 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
44
45
|
"dtype": self.analyze_dtype_in_kwargs
|
|
45
46
|
}
|
|
46
47
|
self._async_dump_cache = {}
|
|
48
|
+
self.api_register = get_api_register()
|
|
47
49
|
|
|
48
50
|
@staticmethod
|
|
49
51
|
def get_md5_for_tensor(x):
|
|
@@ -64,7 +66,7 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
64
66
|
tensor_stat.max = np.max(data_np).item()
|
|
65
67
|
tensor_stat.min = np.min(data_np).item()
|
|
66
68
|
elif not data.shape:
|
|
67
|
-
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
|
|
69
|
+
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
|
|
68
70
|
elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
|
|
69
71
|
data_abs = np.abs(data.asnumpy())
|
|
70
72
|
tensor_stat.max = np.max(data_abs).item()
|
|
@@ -74,83 +76,98 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
74
76
|
else:
|
|
75
77
|
if not ops.is_floating_point(data) or data.dtype == ms.float64:
|
|
76
78
|
data = data.to(ms.float32)
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
|
|
83
|
-
else:
|
|
84
|
-
get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
|
|
85
|
-
tensor_stat.max = get_max_value(data).item()
|
|
86
|
-
tensor_stat.min = get_min_value(data).item()
|
|
87
|
-
tensor_stat.mean = get_mean_value(data).item()
|
|
88
|
-
tensor_stat.norm = get_norm_value(data).item()
|
|
89
|
-
api_register.norm_inner_op_set_hook_func()
|
|
79
|
+
get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
|
|
80
|
+
tensor_stat.max = mint.max(data)
|
|
81
|
+
tensor_stat.min = mint.min(data)
|
|
82
|
+
tensor_stat.mean = mint.mean(data)
|
|
83
|
+
tensor_stat.norm = get_norm_value(data)
|
|
90
84
|
return tensor_stat
|
|
91
85
|
|
|
92
86
|
@staticmethod
|
|
93
87
|
def get_stat_info_async(data):
|
|
94
88
|
tensor_stat = TensorStatInfo()
|
|
95
|
-
|
|
96
|
-
|
|
89
|
+
if data.dtype == ms.bool_:
|
|
90
|
+
tensor_stat.max = mint.any(data)
|
|
91
|
+
tensor_stat.min = mint.all(data)
|
|
92
|
+
elif not data.shape:
|
|
93
|
+
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
|
|
94
|
+
elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
|
|
97
95
|
logger.warning("Async dump do not support complex data!")
|
|
98
96
|
return tensor_stat
|
|
99
|
-
elif data.dtype == ms.bool_:
|
|
100
|
-
tensor_stat.stack_tensor_stat = (["Max", "Min"], stack_method([data.any(), data.all()]))
|
|
101
|
-
elif not data.shape:
|
|
102
|
-
tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method([data, data, data, data]))
|
|
103
97
|
else:
|
|
104
98
|
if not ops.is_floating_point(data) or data.dtype == ms.float64:
|
|
105
99
|
data = data.to(ms.float32)
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
|
|
112
|
-
else:
|
|
113
|
-
get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
|
|
114
|
-
tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method(
|
|
115
|
-
[get_max_value(data), get_min_value(data), get_mean_value(data), get_norm_value(data)]))
|
|
116
|
-
api_register.norm_inner_op_set_hook_func()
|
|
100
|
+
get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
|
|
101
|
+
tensor_stat.max = mint.max(data)
|
|
102
|
+
tensor_stat.min = mint.min(data)
|
|
103
|
+
tensor_stat.mean = mint.mean(data)
|
|
104
|
+
tensor_stat.norm = get_norm_value(data)
|
|
117
105
|
return tensor_stat
|
|
118
106
|
|
|
119
107
|
@staticmethod
|
|
120
108
|
def is_hookable_element(element):
|
|
121
109
|
return hasattr(element, "register_hook") and callable(element.register_hook)
|
|
122
110
|
|
|
111
|
+
@staticmethod
|
|
112
|
+
def process_group_hash(arg):
|
|
113
|
+
group_ranks = distributed.get_process_group_ranks(arg)
|
|
114
|
+
group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8'))
|
|
115
|
+
return f"{group_ranks_hash:08x}"
|
|
116
|
+
|
|
123
117
|
@classmethod
|
|
124
118
|
def get_special_types(cls):
|
|
125
119
|
return super().get_special_types() + cls.mindspore_special_type
|
|
126
120
|
|
|
121
|
+
def dump_async_data(self):
|
|
122
|
+
for file_path, tensor in self._async_dump_cache.items():
|
|
123
|
+
save_tensor_as_npy(tensor, file_path)
|
|
124
|
+
self._async_dump_cache.clear()
|
|
125
|
+
|
|
127
126
|
def get_stat_info(self, data):
|
|
127
|
+
self.api_register.restore_inner_used_api()
|
|
128
128
|
tensor_stat = TensorStatInfo()
|
|
129
129
|
if data.numel() == 0:
|
|
130
|
-
|
|
130
|
+
stat_info = tensor_stat
|
|
131
131
|
else:
|
|
132
132
|
if self.config.async_dump:
|
|
133
|
-
|
|
133
|
+
stat_info = MindsporeDataProcessor.get_stat_info_async(data)
|
|
134
134
|
else:
|
|
135
|
-
|
|
135
|
+
stat_info = MindsporeDataProcessor.get_stat_info_sync(data)
|
|
136
|
+
self.api_register.register_inner_used_api()
|
|
137
|
+
return stat_info
|
|
136
138
|
|
|
137
139
|
def analyze_single_element(self, element, suffix_stack):
|
|
138
140
|
if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
|
|
139
141
|
return self.mindspore_object_key[suffix_stack[-1]](element)
|
|
140
142
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
143
|
+
suffix_str = Const.SEP.join(str(s) for s in suffix_stack)
|
|
144
|
+
type_analyzer = [
|
|
145
|
+
(MindsporeDataProcessor.builtin_type, self._analyze_builtin),
|
|
146
|
+
(ms.Tensor, lambda e: self._analyze_tensor(e, suffix_str)),
|
|
147
|
+
(Number, self.analyze_dtype_in_kwargs),
|
|
148
|
+
(MindsporeDataProcessor.np_type[:-1], self._analyze_numpy),
|
|
149
|
+
(np.ndarray, lambda e: self._analyze_ndarray(e, suffix_str)),
|
|
150
|
+
(distributed.P2POp, lambda e: self._analyze_p2pop(e, suffix_str))
|
|
151
|
+
]
|
|
152
|
+
for type_key, analyze_fn in type_analyzer:
|
|
153
|
+
if isinstance(element, type_key):
|
|
154
|
+
return analyze_fn(element)
|
|
152
155
|
return {}
|
|
153
156
|
|
|
157
|
+
def _analyze_p2pop(self, arg, suffix):
|
|
158
|
+
p2pop_info = {"class_type": "mindspore.mint.distributed.P2POp"}
|
|
159
|
+
try:
|
|
160
|
+
tensor_info = self._analyze_tensor(arg.tensor, suffix)
|
|
161
|
+
p2pop_info.update({"tensor": tensor_info})
|
|
162
|
+
p2pop_info.update({"op": arg.op})
|
|
163
|
+
p2pop_info.update({"peer": arg.peer})
|
|
164
|
+
p2pop_info.update({"tag": arg.tag})
|
|
165
|
+
group_id = self.process_group_hash(arg.group) if arg.group else None
|
|
166
|
+
p2pop_info.update({"group_id": group_id})
|
|
167
|
+
except Exception as e:
|
|
168
|
+
logger.warning(f"Failed to parse the P2POp content with error info: {e}.")
|
|
169
|
+
return p2pop_info
|
|
170
|
+
|
|
154
171
|
def _analyze_tensor(self, tensor, suffix):
|
|
155
172
|
tensor_stat = self.get_stat_info(tensor)
|
|
156
173
|
tensor_json = {
|
|
@@ -159,45 +176,54 @@ class MindsporeDataProcessor(BaseDataProcessor):
|
|
|
159
176
|
'shape': tensor.shape
|
|
160
177
|
}
|
|
161
178
|
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
179
|
+
# 将统计值存入全局 buffer,并返回占位索引
|
|
180
|
+
stat_values = [
|
|
181
|
+
tensor_stat.max,
|
|
182
|
+
tensor_stat.min,
|
|
183
|
+
tensor_stat.mean,
|
|
184
|
+
tensor_stat.norm
|
|
185
|
+
]
|
|
186
|
+
|
|
187
|
+
placeholder_index = self.data_writer.append_stat_to_buffer(stat_values)
|
|
188
|
+
|
|
189
|
+
tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index})
|
|
190
|
+
|
|
169
191
|
if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
|
|
170
192
|
tensor_md5 = self.get_md5_for_tensor(tensor)
|
|
171
193
|
tensor_json.update({Const.MD5: tensor_md5})
|
|
172
194
|
return tensor_json
|
|
173
195
|
|
|
174
|
-
|
|
175
|
-
class StatisticsDataProcessor(MindsporeDataProcessor):
|
|
176
|
-
pass
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
class TensorDataProcessor(MindsporeDataProcessor):
|
|
180
|
-
def dump_async_data(self):
|
|
181
|
-
for file_path, tensor in self._async_dump_cache.items():
|
|
182
|
-
save_tensor_as_npy(tensor, file_path)
|
|
183
|
-
self._async_dump_cache.clear()
|
|
184
|
-
|
|
185
|
-
def _analyze_tensor(self, tensor, suffix):
|
|
196
|
+
def _analyze_and_save_tensor(self, tensor, suffix):
|
|
186
197
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
187
|
-
single_arg =
|
|
198
|
+
single_arg = MindsporeDataProcessor._analyze_tensor(self, tensor, suffix)
|
|
188
199
|
single_arg.update({"data_name": dump_data_name})
|
|
189
200
|
if self.config.async_dump:
|
|
190
201
|
self._async_dump_cache[file_path] = tensor.copy()
|
|
191
202
|
else:
|
|
192
203
|
save_tensor_as_npy(tensor, file_path)
|
|
193
204
|
return single_arg
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class StatisticsDataProcessor(MindsporeDataProcessor):
|
|
208
|
+
def _analyze_tensor(self, tensor, suffix):
|
|
209
|
+
if any(item in self.current_api_or_module_name for item in self.config.tensor_list):
|
|
210
|
+
return self._analyze_and_save_tensor(tensor, suffix)
|
|
211
|
+
else:
|
|
212
|
+
return super()._analyze_tensor(tensor, suffix)
|
|
213
|
+
|
|
214
|
+
def _analyze_ndarray(self, ndarray, suffix):
|
|
215
|
+
if any(item in self.current_api_or_module_name for item in self.config.tensor_list):
|
|
216
|
+
return self._analyze_and_save_ndarray(ndarray, suffix)
|
|
217
|
+
else:
|
|
218
|
+
return super()._analyze_ndarray(ndarray, suffix)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
class TensorDataProcessor(MindsporeDataProcessor):
|
|
222
|
+
def _analyze_tensor(self, tensor, suffix):
|
|
223
|
+
return self._analyze_and_save_tensor(tensor, suffix)
|
|
224
|
+
|
|
225
|
+
def _analyze_ndarray(self, ndarray, suffix):
|
|
226
|
+
return self._analyze_and_save_ndarray(ndarray, suffix)
|
|
201
227
|
|
|
202
228
|
|
|
203
229
|
class OverflowCheckDataProcessor(MindsporeDataProcessor):
|
|
@@ -262,11 +288,26 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
|
|
|
262
288
|
self.cached_tensors_and_file_paths = {}
|
|
263
289
|
|
|
264
290
|
def _analyze_maybe_overflow_tensor(self, tensor_json):
|
|
265
|
-
|
|
291
|
+
tensor_stat_index = tensor_json.get(Const.TENSOR_STAT_INDEX)
|
|
292
|
+
if tensor_stat_index is None:
|
|
293
|
+
logger.warning("tensor_stat_index does not exist in tensor_json.")
|
|
294
|
+
return
|
|
295
|
+
max_tensor = self.data_writer.get_buffer_values_max(tensor_stat_index)
|
|
296
|
+
min_tensor = self.data_writer.get_buffer_values_min(tensor_stat_index)
|
|
297
|
+
if max_tensor is None or min_tensor is None:
|
|
266
298
|
return
|
|
267
|
-
|
|
299
|
+
|
|
300
|
+
def check_inf_nan(value):
|
|
301
|
+
# Use .item() if it's a tensor-like structure
|
|
302
|
+
if hasattr(value, "item"):
|
|
303
|
+
value = value.item()
|
|
304
|
+
return np.isinf(value) or np.isnan(value)
|
|
305
|
+
|
|
306
|
+
if check_inf_nan(max_tensor):
|
|
268
307
|
self.has_overflow = True
|
|
269
|
-
|
|
308
|
+
return
|
|
309
|
+
|
|
310
|
+
if check_inf_nan(min_tensor):
|
|
270
311
|
self.has_overflow = True
|
|
271
312
|
|
|
272
313
|
def _analyze_tensor(self, tensor, suffix):
|