mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -13,19 +13,24 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import hashlib
|
|
16
17
|
import zlib
|
|
17
18
|
from dataclasses import asdict
|
|
18
19
|
from typing import List
|
|
19
20
|
|
|
20
21
|
import numpy as np
|
|
21
22
|
import torch
|
|
23
|
+
from torch import distributed as dist
|
|
24
|
+
|
|
22
25
|
from msprobe.core.common.const import Const
|
|
23
26
|
from msprobe.core.common.file_utils import path_len_exceeds_limit
|
|
24
27
|
from msprobe.core.common.log import logger
|
|
28
|
+
from msprobe.core.common.utils import convert_tuple
|
|
25
29
|
from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
|
|
26
30
|
ModuleForwardInputsOutputs, TensorStatInfo
|
|
27
31
|
from msprobe.pytorch.common.utils import save_pt, load_pt
|
|
28
32
|
from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
|
|
33
|
+
from msprobe.core.common.utils import recursion_depth_decorator
|
|
29
34
|
|
|
30
35
|
is_gpu = False
|
|
31
36
|
try:
|
|
@@ -35,7 +40,13 @@ except ImportError:
|
|
|
35
40
|
|
|
36
41
|
|
|
37
42
|
class PytorchDataProcessor(BaseDataProcessor):
|
|
38
|
-
pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor)
|
|
43
|
+
pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor, torch.memory_format, dist.ProcessGroup)
|
|
44
|
+
memory_format = {
|
|
45
|
+
torch.contiguous_format: "contiguous_format",
|
|
46
|
+
torch.channels_last: "channels_last",
|
|
47
|
+
torch.channels_last_3d: "channels_last_3d",
|
|
48
|
+
torch.preserve_format: "preserve_format"
|
|
49
|
+
}
|
|
39
50
|
|
|
40
51
|
def __init__(self, config, data_writer):
|
|
41
52
|
super().__init__(config, data_writer)
|
|
@@ -43,6 +54,7 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
43
54
|
"device": self.analyze_device_in_kwargs,
|
|
44
55
|
"dtype": self.analyze_dtype_in_kwargs
|
|
45
56
|
}
|
|
57
|
+
self._async_dump_cache = {}
|
|
46
58
|
|
|
47
59
|
@staticmethod
|
|
48
60
|
def get_md5_for_tensor(x):
|
|
@@ -71,53 +83,114 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
71
83
|
return {"type": "torch.dtype", "value": str(element)}
|
|
72
84
|
|
|
73
85
|
@staticmethod
|
|
74
|
-
def
|
|
86
|
+
def get_stat_info_async(data):
|
|
75
87
|
tensor_stat = TensorStatInfo()
|
|
76
|
-
if data
|
|
77
|
-
|
|
78
|
-
data_clone = data.detach()
|
|
79
|
-
if data_clone.numel() == 0:
|
|
88
|
+
if torch.is_complex(data):
|
|
89
|
+
logger.warning("Async dump do not support complex data!")
|
|
80
90
|
return tensor_stat
|
|
81
|
-
elif
|
|
82
|
-
tensor_stat.
|
|
83
|
-
|
|
84
|
-
elif not
|
|
85
|
-
tensor_stat.
|
|
86
|
-
|
|
87
|
-
|
|
91
|
+
elif data.dtype == torch.bool:
|
|
92
|
+
tensor_stat.stack_tensor_stat = (["Max", "Min"], torch.stack(
|
|
93
|
+
[torch.any(data), torch.all(data)]))
|
|
94
|
+
elif not data.shape:
|
|
95
|
+
tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([data, data, data, data]))
|
|
96
|
+
else:
|
|
97
|
+
if not data.is_floating_point() or data.dtype == torch.float64:
|
|
98
|
+
data = data.float()
|
|
99
|
+
tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([
|
|
100
|
+
torch.max(data),
|
|
101
|
+
torch.min(data),
|
|
102
|
+
torch.mean(data),
|
|
103
|
+
torch.norm(data)
|
|
104
|
+
]))
|
|
105
|
+
return tensor_stat
|
|
106
|
+
|
|
107
|
+
@staticmethod
|
|
108
|
+
def get_stat_info_sync(data):
|
|
109
|
+
tensor_stat = TensorStatInfo()
|
|
110
|
+
if torch.is_complex(data):
|
|
111
|
+
data_np = data.cpu().numpy()
|
|
88
112
|
data_abs = np.abs(data_np)
|
|
89
113
|
tensor_stat.max = np.max(data_abs).item()
|
|
90
114
|
tensor_stat.min = np.min(data_abs).item()
|
|
91
115
|
tensor_stat.mean = np.mean(data_abs).item()
|
|
116
|
+
elif data.dtype == torch.bool:
|
|
117
|
+
tensor_stat.max = torch.any(data).item()
|
|
118
|
+
tensor_stat.min = torch.all(data).item()
|
|
119
|
+
elif not data.shape:
|
|
120
|
+
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
|
|
92
121
|
else:
|
|
93
|
-
if not
|
|
94
|
-
|
|
95
|
-
tensor_stat.max = torch.
|
|
96
|
-
tensor_stat.min = torch.
|
|
97
|
-
tensor_stat.mean = torch.
|
|
98
|
-
tensor_stat.norm = torch.
|
|
122
|
+
if not data.is_floating_point() or data.dtype == torch.float64:
|
|
123
|
+
data = data.float()
|
|
124
|
+
tensor_stat.max = torch.max(data).item()
|
|
125
|
+
tensor_stat.min = torch.min(data).item()
|
|
126
|
+
tensor_stat.mean = torch.mean(data).item()
|
|
127
|
+
tensor_stat.norm = torch.norm(data).item()
|
|
99
128
|
return tensor_stat
|
|
100
129
|
|
|
130
|
+
@staticmethod
|
|
131
|
+
def get_stat_info(data, async_dump=False):
|
|
132
|
+
tensor_stat = TensorStatInfo()
|
|
133
|
+
if data.is_meta:
|
|
134
|
+
return tensor_stat
|
|
135
|
+
data_clone = data.detach()
|
|
136
|
+
if data_clone.numel() == 0:
|
|
137
|
+
return tensor_stat
|
|
138
|
+
else:
|
|
139
|
+
if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump:
|
|
140
|
+
return PytorchDataProcessor.get_stat_info_sync(data_clone)
|
|
141
|
+
else:
|
|
142
|
+
return PytorchDataProcessor.get_stat_info_async(data_clone)
|
|
143
|
+
|
|
101
144
|
@staticmethod
|
|
102
145
|
def handle_tensor_extremum_nan_inf(tensor, operator):
|
|
103
146
|
data_clone = tensor.detach()
|
|
104
|
-
data_nan = torch.
|
|
105
|
-
if int(torch.
|
|
147
|
+
data_nan = torch.isnan(data_clone)
|
|
148
|
+
if int(torch.sum(data_nan)) == data_clone.numel():
|
|
106
149
|
return float('nan')
|
|
107
|
-
|
|
108
|
-
|
|
150
|
+
|
|
151
|
+
finite_mask = torch.isfinite(data_clone)
|
|
152
|
+
if int(torch.sum(finite_mask)) > 0:
|
|
109
153
|
finite_values = data_clone[finite_mask]
|
|
110
|
-
return torch.
|
|
111
|
-
torch.
|
|
154
|
+
return torch.max(finite_values).item() if operator == 'max' else \
|
|
155
|
+
torch.min(finite_values).item()
|
|
112
156
|
else:
|
|
113
157
|
data_no_nan = data_clone[~data_nan]
|
|
114
|
-
return torch.
|
|
115
|
-
torch.
|
|
158
|
+
return torch.max(data_no_nan).item() if operator == 'max' else \
|
|
159
|
+
torch.min(data_no_nan).item()
|
|
160
|
+
|
|
161
|
+
@staticmethod
|
|
162
|
+
def process_group_hash(arg):
|
|
163
|
+
group_ranks = dist.get_process_group_ranks(arg)
|
|
164
|
+
group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest()
|
|
165
|
+
return group_ranks_hash
|
|
166
|
+
|
|
167
|
+
@staticmethod
|
|
168
|
+
def is_distributed_op(module):
|
|
169
|
+
return getattr(module, "op_is_distributed", False)
|
|
116
170
|
|
|
117
171
|
@staticmethod
|
|
118
172
|
def _analyze_torch_size(arg):
|
|
119
173
|
return {"type": "torch.Size", "value": list(arg)}
|
|
120
174
|
|
|
175
|
+
@staticmethod
|
|
176
|
+
def _analyze_memory_format(arg):
|
|
177
|
+
# 获取内存格式
|
|
178
|
+
format_type = PytorchDataProcessor.memory_format.get(arg)
|
|
179
|
+
|
|
180
|
+
return {"type": "torch.memory_format", "format": format_type}
|
|
181
|
+
|
|
182
|
+
@staticmethod
|
|
183
|
+
def _analyze_process_group(arg):
|
|
184
|
+
group_info = {"type": "torch.ProcessGroup"}
|
|
185
|
+
try:
|
|
186
|
+
group_ranks = dist.get_process_group_ranks(arg)
|
|
187
|
+
group_info.update({"group_ranks": group_ranks})
|
|
188
|
+
group_id = PytorchDataProcessor.process_group_hash(arg)
|
|
189
|
+
group_info.update({"group_id": group_id})
|
|
190
|
+
except Exception as e:
|
|
191
|
+
logger.warning(f"Failed to get process group(id: {group_id}) ranks info with error info: {e}.")
|
|
192
|
+
return group_info
|
|
193
|
+
|
|
121
194
|
@classmethod
|
|
122
195
|
def get_special_types(cls):
|
|
123
196
|
return super().get_special_types() + cls.pytorch_special_type
|
|
@@ -127,6 +200,10 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
127
200
|
return self.torch_object_key[suffix_stack[-1]](element)
|
|
128
201
|
if isinstance(element, torch.Size):
|
|
129
202
|
return self._analyze_torch_size(element)
|
|
203
|
+
if isinstance(element, torch.memory_format):
|
|
204
|
+
return self._analyze_memory_format(element)
|
|
205
|
+
if isinstance(element, dist.ProcessGroup):
|
|
206
|
+
return self._analyze_process_group(element)
|
|
130
207
|
converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
|
|
131
208
|
if converted_numpy is not element:
|
|
132
209
|
return self._analyze_numpy(converted_numpy, numpy_type)
|
|
@@ -136,26 +213,35 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
136
213
|
return self._analyze_builtin(element)
|
|
137
214
|
return {}
|
|
138
215
|
|
|
216
|
+
def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
217
|
+
if self.is_distributed_op(module):
|
|
218
|
+
module_input_output.update_output_with_args_and_kwargs()
|
|
219
|
+
return super().analyze_forward_output(name, module, module_input_output)
|
|
220
|
+
|
|
139
221
|
def _analyze_tensor(self, tensor, suffix):
|
|
140
|
-
tensor_stat = self.get_stat_info(tensor)
|
|
222
|
+
tensor_stat = self.get_stat_info(tensor, self.config.async_dump)
|
|
141
223
|
tensor_json = {}
|
|
142
224
|
tensor_json.update({'type': 'torch.Tensor'})
|
|
143
225
|
tensor_json.update({'dtype': str(tensor.dtype)})
|
|
144
226
|
tensor_json.update({"shape": tensor.shape})
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
227
|
+
if tensor_stat.stack_tensor_stat is None:
|
|
228
|
+
tensor_json.update({"Max": tensor_stat.max})
|
|
229
|
+
tensor_json.update({"Min": tensor_stat.min})
|
|
230
|
+
tensor_json.update({"Mean": tensor_stat.mean})
|
|
231
|
+
tensor_json.update({"Norm": tensor_stat.norm})
|
|
232
|
+
tensor_json.update({"requires_grad": tensor.requires_grad})
|
|
233
|
+
if tensor_stat.max is not None:
|
|
234
|
+
if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
|
|
235
|
+
tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
|
|
236
|
+
if tensor_stat.min is not None:
|
|
237
|
+
if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
|
|
238
|
+
tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")
|
|
239
|
+
|
|
240
|
+
else:
|
|
241
|
+
tensor_json.update({"requires_grad": tensor.requires_grad})
|
|
242
|
+
tensor_json.update({"tensor_stat": tensor_stat.stack_tensor_stat})
|
|
243
|
+
|
|
244
|
+
if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
|
|
159
245
|
tensor_md5 = self.get_md5_for_tensor(tensor)
|
|
160
246
|
tensor_json.update({Const.MD5: tensor_md5})
|
|
161
247
|
return tensor_json
|
|
@@ -166,12 +252,20 @@ class StatisticsDataProcessor(PytorchDataProcessor):
|
|
|
166
252
|
|
|
167
253
|
|
|
168
254
|
class TensorDataProcessor(PytorchDataProcessor):
|
|
255
|
+
def dump_async_data(self):
|
|
256
|
+
for file_path, tensor in self._async_dump_cache.items():
|
|
257
|
+
save_pt(tensor.contiguous(), file_path)
|
|
258
|
+
self._async_dump_cache.clear()
|
|
259
|
+
|
|
169
260
|
def _analyze_tensor(self, tensor, suffix):
|
|
170
261
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
171
|
-
saved_tensor = tensor.clone().contiguous().detach()
|
|
172
|
-
save_pt(saved_tensor, file_path)
|
|
173
262
|
single_arg = super()._analyze_tensor(tensor, suffix)
|
|
174
263
|
single_arg.update({"data_name": dump_data_name})
|
|
264
|
+
if self.config.async_dump:
|
|
265
|
+
self._async_dump_cache[file_path] = tensor.clone().detach()
|
|
266
|
+
else:
|
|
267
|
+
saved_tensor = tensor.clone().contiguous().detach()
|
|
268
|
+
save_pt(saved_tensor, file_path)
|
|
175
269
|
return single_arg
|
|
176
270
|
|
|
177
271
|
|
|
@@ -182,7 +276,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
182
276
|
super().__init__(config, data_writer)
|
|
183
277
|
self.has_overflow = False
|
|
184
278
|
self.support_inf_nan = None
|
|
185
|
-
self.
|
|
279
|
+
self.cached_api_info = {}
|
|
186
280
|
self.cached_tensors_and_file_paths = {}
|
|
187
281
|
self.bits_for_overflow = 8
|
|
188
282
|
self.real_overflow_nums = 0
|
|
@@ -196,21 +290,21 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
196
290
|
return True
|
|
197
291
|
return False
|
|
198
292
|
|
|
199
|
-
def
|
|
293
|
+
def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
200
294
|
self.has_overflow = False
|
|
201
295
|
self._is_support_inf_nan()
|
|
202
|
-
self.
|
|
296
|
+
self.cached_api_info = super().analyze_forward_input(name, module, module_input_output)
|
|
203
297
|
return None
|
|
204
298
|
|
|
205
|
-
def
|
|
299
|
+
def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
206
300
|
self._is_support_inf_nan()
|
|
207
|
-
api_info_struct = super().
|
|
208
|
-
if name in self.
|
|
209
|
-
self.
|
|
301
|
+
api_info_struct = super().analyze_forward_output(name, module, module_input_output)
|
|
302
|
+
if name in self.cached_api_info and name in api_info_struct:
|
|
303
|
+
self.cached_api_info[name].update(api_info_struct[name])
|
|
210
304
|
elif name in api_info_struct:
|
|
211
|
-
self.
|
|
305
|
+
self.cached_api_info = api_info_struct
|
|
212
306
|
self.handle_overflow()
|
|
213
|
-
return self.
|
|
307
|
+
return self.cached_api_info if self.has_overflow else None
|
|
214
308
|
|
|
215
309
|
def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
216
310
|
self.has_overflow = False
|
|
@@ -225,6 +319,13 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
225
319
|
api_info_struct = super().analyze_backward(name, module, module_input_output)
|
|
226
320
|
self.handle_overflow()
|
|
227
321
|
return api_info_struct if self.has_overflow else None
|
|
322
|
+
|
|
323
|
+
def analyze_params(self, name, param_name, grad):
|
|
324
|
+
self.has_overflow = False
|
|
325
|
+
self._is_support_inf_nan()
|
|
326
|
+
api_info_struct = super().analyze_params(name, param_name, grad)
|
|
327
|
+
self.handle_overflow()
|
|
328
|
+
return api_info_struct if self.has_overflow else None
|
|
228
329
|
|
|
229
330
|
def handle_overflow(self):
|
|
230
331
|
if not self.support_inf_nan:
|
|
@@ -299,10 +400,10 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor):
|
|
|
299
400
|
)
|
|
300
401
|
return
|
|
301
402
|
|
|
302
|
-
def
|
|
403
|
+
def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
303
404
|
self.checker.pre_forward(name, module, self, module_input_output.args, module_input_output.kwargs)
|
|
304
405
|
|
|
305
|
-
def
|
|
406
|
+
def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
306
407
|
new_output, unequal_rows = self.checker.forward(
|
|
307
408
|
name,
|
|
308
409
|
module,
|
|
@@ -320,64 +421,120 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor):
|
|
|
320
421
|
|
|
321
422
|
|
|
322
423
|
class KernelDumpDataProcessor(PytorchDataProcessor):
|
|
323
|
-
forward_init_status = False
|
|
324
|
-
multi_output_apis = ["_sort_", "npu_flash_attention"]
|
|
325
|
-
|
|
326
424
|
def __init__(self, config, data_writer):
|
|
327
425
|
super().__init__(config, data_writer)
|
|
426
|
+
self.enable_kernel_dump = True
|
|
427
|
+
self.is_found_output_tensor = False
|
|
428
|
+
self.is_found_grad_input_tensor = False
|
|
429
|
+
self.forward_args = None
|
|
430
|
+
self.forward_kwargs = None
|
|
431
|
+
self.forward_output_tensor = None
|
|
432
|
+
self.grad_input_tensor = None
|
|
433
|
+
|
|
434
|
+
@staticmethod
|
|
435
|
+
def start_kernel_dump(config_path):
|
|
436
|
+
torch_npu.npu.synchronize()
|
|
437
|
+
torch_npu.npu.init_dump()
|
|
438
|
+
torch_npu.npu.set_dump(config_path)
|
|
439
|
+
torch_npu.npu.synchronize()
|
|
440
|
+
|
|
441
|
+
@staticmethod
|
|
442
|
+
def stop_kernel_dump():
|
|
443
|
+
torch_npu.npu.synchronize()
|
|
444
|
+
torch_npu.npu.finalize_dump()
|
|
445
|
+
torch_npu.npu.synchronize()
|
|
446
|
+
|
|
447
|
+
@staticmethod
|
|
448
|
+
def _print_unsupported_log(api_name):
|
|
449
|
+
logger.warning(f"The kernel dump does not support the {api_name} API.")
|
|
450
|
+
|
|
451
|
+
def analyze_forward_input(self, name, module, module_input_output):
|
|
452
|
+
if not self.enable_kernel_dump:
|
|
453
|
+
return
|
|
454
|
+
if is_gpu:
|
|
455
|
+
logger.warning("The current environment is not a complete NPU environment, and kernel dump cannot be used.")
|
|
456
|
+
self.enable_kernel_dump = False
|
|
457
|
+
return
|
|
458
|
+
|
|
459
|
+
if self.config.is_backward_kernel_dump:
|
|
460
|
+
self.forward_args = self.clone_and_detach_tensor(module_input_output.args)
|
|
461
|
+
self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs)
|
|
462
|
+
try:
|
|
463
|
+
output = module.forward(*self.forward_args, **self.forward_kwargs)
|
|
464
|
+
except Exception:
|
|
465
|
+
self._print_unsupported_log(name)
|
|
466
|
+
self.enable_kernel_dump = False
|
|
467
|
+
return
|
|
468
|
+
|
|
469
|
+
self.analyze_element(convert_tuple(output))
|
|
470
|
+
if not self.is_found_output_tensor:
|
|
471
|
+
self._print_unsupported_log(name)
|
|
472
|
+
self.enable_kernel_dump = False
|
|
473
|
+
return
|
|
474
|
+
self.start_kernel_dump(self.config.kernel_config_path)
|
|
475
|
+
|
|
476
|
+
def analyze_forward_output(self, name, module, module_input_output):
|
|
477
|
+
if not self.enable_kernel_dump:
|
|
478
|
+
return
|
|
479
|
+
if self.config.is_backward_kernel_dump:
|
|
480
|
+
return
|
|
481
|
+
self.enable_kernel_dump = False
|
|
482
|
+
self.stop_kernel_dump()
|
|
483
|
+
logger.info(f"The kernel data of {name} is dumped successfully.")
|
|
484
|
+
|
|
485
|
+
def analyze_backward(self, name, module, module_input_output):
|
|
486
|
+
if not self.enable_kernel_dump:
|
|
487
|
+
return
|
|
488
|
+
self.enable_kernel_dump = False
|
|
489
|
+
|
|
490
|
+
self.analyze_element(module_input_output.grad_input)
|
|
491
|
+
if not self.is_found_grad_input_tensor:
|
|
492
|
+
self._print_unsupported_log(name)
|
|
493
|
+
return
|
|
494
|
+
self.start_kernel_dump(self.config.kernel_config_path)
|
|
495
|
+
|
|
496
|
+
try:
|
|
497
|
+
self.forward_output_tensor.backward(self.grad_input_tensor, retain_graph=True)
|
|
498
|
+
except Exception:
|
|
499
|
+
self._print_unsupported_log(name)
|
|
500
|
+
self.stop_kernel_dump()
|
|
501
|
+
return
|
|
328
502
|
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
503
|
+
self.stop_kernel_dump()
|
|
504
|
+
logger.info(f"The kernel data of {name} is dumped successfully.")
|
|
505
|
+
|
|
506
|
+
@recursion_depth_decorator("KernelDump: KernelDumpDataProcessor.clone_and_detach_tensor")
|
|
507
|
+
def clone_and_detach_tensor(self, input_params):
|
|
508
|
+
if isinstance(input_params, torch.Tensor):
|
|
509
|
+
if input_params.requires_grad:
|
|
510
|
+
return input_params.clone().detach().requires_grad_()
|
|
511
|
+
return input_params.clone()
|
|
512
|
+
elif isinstance(input_params, tuple):
|
|
513
|
+
return tuple(self.clone_and_detach_tensor(x) for x in input_params)
|
|
514
|
+
elif isinstance(input_params, list):
|
|
515
|
+
return list(self.clone_and_detach_tensor(x) for x in input_params)
|
|
516
|
+
elif isinstance(input_params, dict):
|
|
517
|
+
return {k: self.clone_and_detach_tensor(v) for k, v in input_params.items()}
|
|
332
518
|
else:
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
def forward_acl_dump(self, name, module, module_input_output):
|
|
336
|
-
if not KernelDumpDataProcessor.forward_init_status:
|
|
337
|
-
KernelDumpDataProcessor.forward_init_status = True
|
|
338
|
-
torch_npu.npu.synchronize()
|
|
339
|
-
torch_npu.npu.init_dump()
|
|
340
|
-
torch_npu.npu.set_dump(self.config.acl_config)
|
|
341
|
-
torch_npu.npu.synchronize()
|
|
342
|
-
if self.op_need_trigger(name):
|
|
343
|
-
module.forward(*module_input_output.args, **module_input_output.kwargs).cpu()
|
|
344
|
-
else:
|
|
345
|
-
module.forward(*module_input_output.args, **module_input_output.kwargs)
|
|
346
|
-
torch_npu.npu.synchronize()
|
|
347
|
-
torch_npu.npu.finalize_dump()
|
|
348
|
-
torch_npu.npu.synchronize()
|
|
349
|
-
KernelDumpDataProcessor.forward_init_status = False
|
|
350
|
-
logger.info("Dump %s op file." % name)
|
|
351
|
-
|
|
352
|
-
def acl_backward_dump_status(self, output, grad, module_name):
|
|
353
|
-
if isinstance(output, torch.Tensor):
|
|
354
|
-
output.backward(grad, retain_graph=True)
|
|
355
|
-
return True
|
|
519
|
+
return input_params
|
|
356
520
|
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
521
|
+
def analyze_single_element(self, element, suffix_stack):
|
|
522
|
+
if isinstance(element, torch.Tensor):
|
|
523
|
+
if not self.is_found_output_tensor:
|
|
524
|
+
if element.requires_grad:
|
|
525
|
+
self.forward_output_tensor = element
|
|
526
|
+
self.is_found_output_tensor = True
|
|
527
|
+
return {}
|
|
528
|
+
if not self.is_found_grad_input_tensor:
|
|
529
|
+
self.grad_input_tensor = element.clone()
|
|
530
|
+
self.is_found_grad_input_tensor = True
|
|
531
|
+
return {}
|
|
362
532
|
|
|
363
|
-
def
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
torch_npu.npu.set_dump(self.config.acl_config)
|
|
372
|
-
torch_npu.npu.synchronize()
|
|
373
|
-
if not self.acl_backward_dump_status(output, grad, name):
|
|
374
|
-
logger.warning("The output of {} is not of tensor type and cannot be automatically derived. "
|
|
375
|
-
"you can manually construct a single API backward case for ACL dump.".format(
|
|
376
|
-
name))
|
|
377
|
-
torch_npu.npu.synchronize()
|
|
378
|
-
torch_npu.npu.finalize_dump()
|
|
379
|
-
KernelDumpDataProcessor.forward_init_status = False
|
|
380
|
-
logger.info("Dump %s op file." % name)
|
|
381
|
-
|
|
382
|
-
def op_need_trigger(self, module_name):
|
|
383
|
-
return 'Tensor.__getitem__.' in module_name
|
|
533
|
+
def reset_status(self):
|
|
534
|
+
self.enable_kernel_dump = True
|
|
535
|
+
self.is_found_output_tensor = False
|
|
536
|
+
self.is_found_grad_input_tensor = False
|
|
537
|
+
self.forward_args = None
|
|
538
|
+
self.forward_kwargs = None
|
|
539
|
+
self.forward_output_tensor = None
|
|
540
|
+
self.grad_input_tensor = None
|
|
@@ -15,10 +15,12 @@
|
|
|
15
15
|
|
|
16
16
|
import csv
|
|
17
17
|
import os
|
|
18
|
+
import numpy as np
|
|
18
19
|
|
|
19
20
|
from msprobe.core.common.const import Const, FileCheckConst
|
|
20
|
-
from msprobe.core.common.file_utils import change_mode, FileOpen, save_json
|
|
21
|
+
from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json
|
|
21
22
|
from msprobe.core.common.log import logger
|
|
23
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
22
24
|
|
|
23
25
|
|
|
24
26
|
class DataWriter:
|
|
@@ -115,3 +117,29 @@ class DataWriter:
|
|
|
115
117
|
self.write_stack_info_json(self.stack_file_path)
|
|
116
118
|
if self.cache_construct:
|
|
117
119
|
self.write_construct_info_json(self.construct_file_path)
|
|
120
|
+
|
|
121
|
+
def fill_stack_tensor_data(self):
|
|
122
|
+
self.process_stat_data_recursive(self.cache_data)
|
|
123
|
+
|
|
124
|
+
def process_stat_data_recursive(self, data, depth=0):
|
|
125
|
+
if depth > Const.MAX_DEPTH:
|
|
126
|
+
logger.error(f"The maximum depth of recursive process stat data, {Const.MAX_DEPTH} is reached.")
|
|
127
|
+
raise MsprobeException(MsprobeException.RECURSION_LIMIT_ERROR)
|
|
128
|
+
if isinstance(data, dict):
|
|
129
|
+
if "tensor_stat" in data.keys():
|
|
130
|
+
tensor_stat = data["tensor_stat"]
|
|
131
|
+
if len(tensor_stat) != Const.TENSOR_STAT_LEN or len(tensor_stat[0]) != len(tensor_stat[1]):
|
|
132
|
+
logger.warning("Some bad data in async dump")
|
|
133
|
+
else:
|
|
134
|
+
tensor_stat_index, tensor_stat_data = tensor_stat[0], tensor_stat[1]
|
|
135
|
+
if hasattr(tensor_stat_data, "device") and tensor_stat_data.device != Const.CPU_LOWERCASE:
|
|
136
|
+
tensor_stat_data = tensor_stat_data.cpu()
|
|
137
|
+
for index, stat in zip(tensor_stat_index, tensor_stat_data):
|
|
138
|
+
data.update({index, stat.item()})
|
|
139
|
+
del data["tensor_stat"]
|
|
140
|
+
else:
|
|
141
|
+
for key in data.keys():
|
|
142
|
+
self.process_stat_data_recursive(data[key], depth + 1)
|
|
143
|
+
elif isinstance(data, (list, tuple)):
|
|
144
|
+
for i in data:
|
|
145
|
+
self.process_stat_data_recursive(i, depth + 1)
|