mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import copy
|
|
1
2
|
import os
|
|
2
3
|
import zlib
|
|
3
4
|
from dataclasses import asdict
|
|
@@ -5,18 +6,20 @@ from typing import List
|
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
import torch
|
|
8
|
-
from msprobe.core.common.exceptions import MsaccException
|
|
9
9
|
from msprobe.core.common.file_check import path_len_exceeds_limit, change_mode
|
|
10
10
|
from msprobe.core.common.log import logger
|
|
11
11
|
from msprobe.core.common.const import Const, OverflowConst, FileCheckConst
|
|
12
12
|
from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
|
|
13
13
|
ModuleForwardInputsOutputs, TensorStatInfo
|
|
14
14
|
from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
|
|
15
|
+
from msprobe.pytorch.common.utils import save_pt
|
|
16
|
+
|
|
15
17
|
|
|
16
18
|
try:
|
|
17
19
|
import torch_npu
|
|
20
|
+
is_gpu = False
|
|
18
21
|
except ImportError:
|
|
19
|
-
|
|
22
|
+
is_gpu = True
|
|
20
23
|
|
|
21
24
|
|
|
22
25
|
class PytorchDataProcessor(BaseDataProcessor):
|
|
@@ -68,6 +71,12 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
68
71
|
tensor_stat.min = False not in data_clone
|
|
69
72
|
elif not data_clone.shape:
|
|
70
73
|
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.item()
|
|
74
|
+
elif torch.is_complex(data_clone):
|
|
75
|
+
data_np = data_clone.cpu().numpy()
|
|
76
|
+
data_abs = np.abs(data_np)
|
|
77
|
+
tensor_stat.max = np.max(data_abs).item()
|
|
78
|
+
tensor_stat.min = np.min(data_abs).item()
|
|
79
|
+
tensor_stat.mean = np.mean(data_abs).item()
|
|
71
80
|
else:
|
|
72
81
|
if not data_clone.is_floating_point() or data_clone.dtype == torch.float64:
|
|
73
82
|
data_clone = data_clone.float()
|
|
@@ -76,7 +85,39 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
76
85
|
tensor_stat.mean = torch._C._VariableFunctionsClass.mean(data_clone).item()
|
|
77
86
|
tensor_stat.norm = torch._C._VariableFunctionsClass.norm(data_clone).item()
|
|
78
87
|
return tensor_stat
|
|
79
|
-
|
|
88
|
+
|
|
89
|
+
@staticmethod
|
|
90
|
+
def handle_tensor_extremum_nan_inf(tensor, operator):
|
|
91
|
+
data_clone = tensor.detach()
|
|
92
|
+
data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
|
|
93
|
+
if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
|
|
94
|
+
return float('nan')
|
|
95
|
+
finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
|
|
96
|
+
if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
|
|
97
|
+
finite_values = data_clone[finite_mask]
|
|
98
|
+
return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
|
|
99
|
+
torch._C._VariableFunctionsClass.min(finite_values).item()
|
|
100
|
+
else:
|
|
101
|
+
data_no_nan = data_clone[~data_nan]
|
|
102
|
+
return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
|
|
103
|
+
torch._C._VariableFunctionsClass.min(data_no_nan).item()
|
|
104
|
+
|
|
105
|
+
@staticmethod
|
|
106
|
+
def _analyze_builtin(arg):
|
|
107
|
+
single_arg = {}
|
|
108
|
+
if isinstance(arg, slice):
|
|
109
|
+
single_arg.update({"type": "slice"})
|
|
110
|
+
# slice参数中可能存在tensor类型,json序列化,需要转换为python数值类型
|
|
111
|
+
values = [
|
|
112
|
+
value if not isinstance(value, torch.Tensor) else value.item()
|
|
113
|
+
for value in [arg.start, arg.stop, arg.step]
|
|
114
|
+
]
|
|
115
|
+
single_arg.update({"value": values})
|
|
116
|
+
else:
|
|
117
|
+
single_arg.update({"type": type(arg).__name__})
|
|
118
|
+
single_arg.update({"value": arg})
|
|
119
|
+
return single_arg
|
|
120
|
+
|
|
80
121
|
@staticmethod
|
|
81
122
|
def _analyze_torch_size(arg):
|
|
82
123
|
return {"type": "torch.Size", "value": list(arg)}
|
|
@@ -97,10 +138,7 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
97
138
|
return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
|
|
98
139
|
if isinstance(element, (bool, int, float, str, slice)):
|
|
99
140
|
return self._analyze_builtin(element)
|
|
100
|
-
return
|
|
101
|
-
|
|
102
|
-
def analyze_element(self, element):
|
|
103
|
-
return self.recursive_apply_transform(element, self.analyze_single_element)
|
|
141
|
+
return {}
|
|
104
142
|
|
|
105
143
|
def _analyze_tensor(self, tensor, suffix):
|
|
106
144
|
tensor_stat = self.get_stat_info(tensor)
|
|
@@ -113,9 +151,17 @@ class PytorchDataProcessor(BaseDataProcessor):
|
|
|
113
151
|
tensor_json.update({"Mean": tensor_stat.mean})
|
|
114
152
|
tensor_json.update({"Norm": tensor_stat.norm})
|
|
115
153
|
tensor_json.update({"requires_grad": tensor.requires_grad})
|
|
116
|
-
|
|
154
|
+
|
|
155
|
+
if tensor_stat.max is not None:
|
|
156
|
+
if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
|
|
157
|
+
tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
|
|
158
|
+
if tensor_stat.min is not None:
|
|
159
|
+
if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
|
|
160
|
+
tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")
|
|
161
|
+
|
|
162
|
+
if self.config.summary_mode == Const.MD5:
|
|
117
163
|
tensor_md5 = self.get_md5_for_tensor(tensor)
|
|
118
|
-
tensor_json.update({
|
|
164
|
+
tensor_json.update({Const.MD5: tensor_md5})
|
|
119
165
|
return tensor_json
|
|
120
166
|
|
|
121
167
|
|
|
@@ -126,11 +172,8 @@ class StatisticsDataProcessor(PytorchDataProcessor):
|
|
|
126
172
|
class TensorDataProcessor(PytorchDataProcessor):
|
|
127
173
|
def _analyze_tensor(self, tensor, suffix):
|
|
128
174
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
132
|
-
else:
|
|
133
|
-
logger.warning(f'The file path {file_path} length exceeds limit.')
|
|
175
|
+
saved_tensor = tensor.contiguous().detach()
|
|
176
|
+
save_pt(saved_tensor, file_path)
|
|
134
177
|
single_arg = super()._analyze_tensor(tensor, suffix)
|
|
135
178
|
single_arg.update({"data_name": dump_data_name})
|
|
136
179
|
return single_arg
|
|
@@ -142,29 +185,36 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
142
185
|
def __init__(self, config, data_writer):
|
|
143
186
|
super().__init__(config, data_writer)
|
|
144
187
|
self.cached_tensors_and_file_paths = {}
|
|
145
|
-
self.real_overflow_dump_times = 0
|
|
146
|
-
self.overflow_nums = config.overflow_num
|
|
147
188
|
self.bits_for_overflow = 8
|
|
189
|
+
self.real_overflow_nums = 0
|
|
190
|
+
self.overflow_nums = config.overflow_nums
|
|
191
|
+
self.forward_inplace_inputs = None
|
|
192
|
+
|
|
193
|
+
@property
|
|
194
|
+
def is_terminated(self):
|
|
195
|
+
if self.overflow_nums == -1:
|
|
196
|
+
return False
|
|
197
|
+
if self.real_overflow_nums >= self.overflow_nums:
|
|
198
|
+
logger.info(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_nums}")
|
|
199
|
+
return True
|
|
200
|
+
return False
|
|
148
201
|
|
|
149
202
|
@staticmethod
|
|
150
203
|
def overflow_debug_mode_enable():
|
|
151
204
|
overflow_mode = os.getenv(OverflowConst.OVERFLOW_DEBUG_MODE_ENABLE, Const.ENV_DISABLE)
|
|
152
205
|
return overflow_mode == Const.ENV_ENABLE
|
|
153
206
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
data_no_nan = data_clone[~data_nan]
|
|
166
|
-
return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
|
|
167
|
-
torch._C._VariableFunctionsClass.min(data_no_nan).item()
|
|
207
|
+
def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
|
|
208
|
+
self.forward_inplace_inputs = copy.deepcopy(module_input_output)
|
|
209
|
+
return None
|
|
210
|
+
|
|
211
|
+
def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
|
|
212
|
+
module_input_output.output = module_input_output.concat_args_and_kwargs()
|
|
213
|
+
module_input_output.args = self.forward_inplace_inputs.args
|
|
214
|
+
module_input_output.kwargs = self.forward_inplace_inputs.kwargs
|
|
215
|
+
# release memory used by forward inputs
|
|
216
|
+
self.forward_inplace_inputs = None
|
|
217
|
+
return self.analyze_forward(name, None, module_input_output)
|
|
168
218
|
|
|
169
219
|
def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
170
220
|
self.has_overflow = False
|
|
@@ -181,20 +231,12 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
181
231
|
def maybe_save_overflow_data_and_check_overflow_times(self):
|
|
182
232
|
if self.has_overflow:
|
|
183
233
|
for file_path, tensor in self.cached_tensors_and_file_paths.items():
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
self.inc_and_check_overflow_times()
|
|
234
|
+
save_pt(tensor, file_path)
|
|
235
|
+
self.real_overflow_nums += 1
|
|
187
236
|
self.cached_tensors_and_file_paths = {}
|
|
188
237
|
|
|
189
|
-
def inc_and_check_overflow_times(self):
|
|
190
|
-
self.real_overflow_dump_times += 1
|
|
191
|
-
if self.overflow_nums == -1:
|
|
192
|
-
return
|
|
193
|
-
if self.real_overflow_dump_times >= self.overflow_nums:
|
|
194
|
-
raise MsaccException(MsaccException.OVERFLOW_NUMS_ERROR, str(self.real_overflow_dump_times))
|
|
195
|
-
|
|
196
238
|
def check_overflow_npu(self):
|
|
197
|
-
if self.
|
|
239
|
+
if self.overflow_debug_mode_enable():
|
|
198
240
|
float_status = torch.zeros(self.bits_for_overflow).npu()
|
|
199
241
|
result = torch_npu.npu_get_float_status(float_status, OverflowConst.OVERFLOW_DEBUG_MODE)
|
|
200
242
|
if result.cpu()[0] != 0:
|
|
@@ -211,21 +253,22 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
211
253
|
else:
|
|
212
254
|
torch_npu._C._clear_overflow_npu()
|
|
213
255
|
|
|
214
|
-
def _analyze_maybe_overflow_tensor(self, tensor_json
|
|
215
|
-
|
|
216
|
-
if hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan():
|
|
256
|
+
def _analyze_maybe_overflow_tensor(self, tensor_json):
|
|
257
|
+
if is_gpu or (hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan()):
|
|
217
258
|
if tensor_json['Max'] is None:
|
|
218
259
|
return
|
|
219
260
|
if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']):
|
|
220
|
-
tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "max")
|
|
221
261
|
self.has_overflow = True
|
|
222
262
|
if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']):
|
|
223
|
-
tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "min")
|
|
224
263
|
self.has_overflow = True
|
|
225
264
|
else:
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
self.
|
|
265
|
+
try:
|
|
266
|
+
self.has_overflow = self.check_overflow_npu()
|
|
267
|
+
if self.has_overflow:
|
|
268
|
+
self.clear_overflow_npu()
|
|
269
|
+
except Exception as e:
|
|
270
|
+
logger.error(f"Overflow check failed, the current environment may be abnormal.")
|
|
271
|
+
raise RuntimeError(f"overflow check failed") from e
|
|
229
272
|
|
|
230
273
|
def _analyze_tensor(self, tensor, suffix):
|
|
231
274
|
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
@@ -234,7 +277,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
|
|
|
234
277
|
else:
|
|
235
278
|
logger.warning(f'The file path {file_path} length exceeds limit.')
|
|
236
279
|
single_arg = super()._analyze_tensor(tensor, suffix)
|
|
237
|
-
self._analyze_maybe_overflow_tensor(single_arg
|
|
280
|
+
self._analyze_maybe_overflow_tensor(single_arg)
|
|
238
281
|
single_arg.update({"data_name": dump_data_name})
|
|
239
282
|
return single_arg
|
|
240
283
|
|
|
@@ -280,7 +323,7 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor):
|
|
|
280
323
|
self._forward_new_output = new_output
|
|
281
324
|
|
|
282
325
|
def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
|
|
283
|
-
self.checker.backward(name, module, module_input_output.
|
|
326
|
+
self.checker.backward(name, module, module_input_output.grad_input)
|
|
284
327
|
|
|
285
328
|
|
|
286
329
|
class KernelDumpDataProcessor(PytorchDataProcessor):
|
|
@@ -4,7 +4,7 @@ import fcntl
|
|
|
4
4
|
import json
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
|
|
7
|
-
from msprobe.core.common.file_check import change_mode
|
|
7
|
+
from msprobe.core.common.file_check import change_mode, FileOpen
|
|
8
8
|
from msprobe.core.common.log import logger
|
|
9
9
|
from msprobe.core.common.const import Const, FileCheckConst
|
|
10
10
|
|
|
@@ -30,20 +30,20 @@ class DataWriter:
|
|
|
30
30
|
return
|
|
31
31
|
is_exists = os.path.exists(file_path)
|
|
32
32
|
append = "a+" if is_exists else "w+"
|
|
33
|
-
with
|
|
34
|
-
os.open(file_path, Const.WRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), append, newline=""
|
|
35
|
-
) as csv_file:
|
|
33
|
+
with FileOpen(file_path, append) as csv_file:
|
|
36
34
|
spawn_writer = csv.writer(csv_file)
|
|
37
35
|
if not is_exists:
|
|
38
36
|
spawn_writer.writerow(result_header)
|
|
39
37
|
spawn_writer.writerows([result,])
|
|
38
|
+
is_new_file = not is_exists
|
|
39
|
+
if is_new_file:
|
|
40
|
+
change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
40
41
|
|
|
41
42
|
def initialize_json_file(self, **kwargs):
|
|
42
43
|
kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
|
|
43
|
-
with
|
|
44
|
-
os.open(self.dump_file_path, Const.OVERWRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), 'w'
|
|
45
|
-
) as f:
|
|
44
|
+
with FileOpen(self.dump_file_path, 'w') as f:
|
|
46
45
|
json.dump(kwargs, f)
|
|
46
|
+
change_mode(self.dump_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
47
47
|
|
|
48
48
|
if os.path.exists(self.stack_file_path):
|
|
49
49
|
os.remove(self.stack_file_path)
|
|
@@ -83,7 +83,7 @@ class DataWriter:
|
|
|
83
83
|
def write_data_json(self, file_path):
|
|
84
84
|
logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
|
|
85
85
|
if Path(file_path).exists() and os.path.getsize(file_path) > 0:
|
|
86
|
-
with
|
|
86
|
+
with FileOpen(file_path, "r+") as f:
|
|
87
87
|
fcntl.flock(f, fcntl.LOCK_EX)
|
|
88
88
|
data_to_write = json.load(f)
|
|
89
89
|
fcntl.flock(f, fcntl.LOCK_UN)
|
|
@@ -91,7 +91,7 @@ class DataWriter:
|
|
|
91
91
|
self.init_json['data_path'] = self.dump_tensor_data_dir
|
|
92
92
|
data_to_write = self.init_json
|
|
93
93
|
data_to_write[Const.DATA].update(self.cache_data[Const.DATA])
|
|
94
|
-
with
|
|
94
|
+
with FileOpen(file_path, 'w+') as f:
|
|
95
95
|
fcntl.flock(f, fcntl.LOCK_EX)
|
|
96
96
|
json.dump(data_to_write, f, indent=1)
|
|
97
97
|
fcntl.flock(f, fcntl.LOCK_UN)
|
|
@@ -99,13 +99,13 @@ class DataWriter:
|
|
|
99
99
|
self.cache_data[Const.DATA].clear()
|
|
100
100
|
|
|
101
101
|
def write_stack_info_json(self, file_path):
|
|
102
|
-
with
|
|
102
|
+
with FileOpen(file_path, 'w+') as f:
|
|
103
103
|
fcntl.flock(f, fcntl.LOCK_EX)
|
|
104
104
|
json.dump(self.cache_stack, f, indent=1)
|
|
105
105
|
fcntl.flock(f, fcntl.LOCK_UN)
|
|
106
106
|
|
|
107
107
|
def write_construct_info_json(self, file_path):
|
|
108
|
-
with
|
|
108
|
+
with FileOpen(file_path, 'w+') as f:
|
|
109
109
|
fcntl.flock(f, fcntl.LOCK_EX)
|
|
110
110
|
json.dump(self.cache_construct, f, indent=1)
|
|
111
111
|
fcntl.flock(f, fcntl.LOCK_UN)
|
|
File without changes
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
|
|
2
|
+
class GradConst:
    """Namespace of string/number constants shared by the grad_probe modules."""

    # supported frameworks
    PYTORCH = "PyTorch"
    MindSpore = "MindSpore"
    FRAMEWORKS = {PYTORCH, MindSpore}

    # gradient file suffixes
    NPY_SUFFIX = "npy"
    PT_SUFFIX = "pt"
    GRAD_FILE_SUFFIX = {NPY_SUFFIX, PT_SUFFIX}

    # for callback
    CURRENT_STEP = "current_step"

    PARAM_LIST = "param_list"
    RANK = "rank"
    STEP = "step"
    BOUNDS = "bounds"
    OUTPUT_PATH = "output_path"

    # level const
    LEVEL = "level"
    LEVEL0 = "L0"
    LEVEL1 = "L1"
    LEVEL2 = "L2"
    SUPPORTED_LEVEL = {LEVEL0, LEVEL1, LEVEL2}

    # numpy coding
    STEP_IDX = 0
    SHAPE_DIM_IDX = 4
    MAX_SIZE = 10 * 1024 ** 3  # 10 * 2**30; unit is not shown here (presumably bytes - confirm)

    # direction suffix
    DIR_SUFFIX = "dir.npy"

    # file safety
    DATA_DIR_AUTHORITY = 0o750
    DATA_FILE_AUTHORITY = 0o640
    DIRECTORY_LENGTH = 4096
    FILE_NAME_LENGTH = 255
    FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
    PARAM_VALID_PATTERN = r"^[a-zA-Z0-9_.]+$"
    DIR = "dir"
    FILE = "file"

    STEP_FINISH = "step_finish"

    SUMMARY = "summary"

    # csv header entry
    MD5 = "MD5"
    DISTRIBUTION = "distribution"
    SHAPE = "shape"
    MAX = "max"
    MIN = "min"
    NORM = "norm"
|
|
57
|
+
|
|
58
|
+
# Per-level settings: the csv header columns and whether the level has a
# gradient direction ("have_grad_direction").
level_adp = {
    GradConst.LEVEL0: {
        "header": [GradConst.MD5, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
        "have_grad_direction": False,
    },
    GradConst.LEVEL1: {
        "header": [GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
        "have_grad_direction": True,
    },
    GradConst.LEVEL2: {
        "header": [GradConst.DISTRIBUTION, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
        "have_grad_direction": True,
    },
}
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
from tqdm import tqdm
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
|
|
8
|
+
from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create
|
|
9
|
+
from msprobe.core.common.file_check import create_directory
|
|
10
|
+
from msprobe.core.common.log import logger
|
|
11
|
+
from msprobe.core.common.utils import remove_path, write_csv, load_npy
|
|
12
|
+
from msprobe.core.grad_probe.constant import GradConst
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class GradComparator:
    """Compare boolean gradient dumps of two runs and report per-parameter similarity.

    Both inputs are directories containing ``step<N>`` sub-directories (and,
    for the distributed entry point, ``rank<N>`` directories above those).
    For every step present on both sides the fraction of element-wise equal
    entries is computed per parameter, written to ``similarities.csv`` and
    plotted into ``similarities_picture/``.
    """

    @staticmethod
    def _get_grad_weight_order(path1, path2):
        """Return the ``param_name`` column of the first csv found in both directories.

        Raises:
            RuntimeError: if no ``.csv`` file name exists in both ``path1`` and ``path2``.
        """
        for summary_file in os.listdir(path1):
            if not summary_file.endswith(".csv"):
                continue
            if not os.path.exists(os.path.join(path2, summary_file)):
                continue
            summary_csv = pd.read_csv(os.path.join(path1, summary_file))
            return summary_csv["param_name"]
        raise RuntimeError("no matched grad_summary.csv for comparison, please dump data in same configuration")

    @staticmethod
    def _get_name_matched_grad_file(param_name, grad_files):
        """Return the file in ``grad_files`` whose name without its last suffix equals ``param_name``.

        Raises:
            RuntimeError: if no file matches.
        """
        for grad_file in grad_files:
            if param_name == grad_file[:grad_file.rfind('.')]:
                return grad_file
        raise RuntimeError("no matched grad_file for comparison, please dump data in same configuration")

    @classmethod
    def compare_distributed(cls, path1: str, path2: str, output_dir: str):
        """Run :meth:`compare` for every ``rank<N>`` directory present under both paths.

        Results for rank N are written to ``<output_dir>/rank<N>``.

        Raises:
            RuntimeError: if the two paths share no rank directory.
        """
        ranks = cls._get_matched_dirs(path1, path2, "rank")
        logger.info(f"the following ranks will be compared: {ranks}")
        if not ranks:
            raise RuntimeError("no matched ranks for comparison, please dump data in same configuration")
        if not os.path.isdir(output_dir):
            create_directory(output_dir)
        for rank in tqdm(ranks, desc="rank"):
            logger.info(f"now comparing rank {rank}:")
            rank_dir = f"rank{rank}"
            cls.compare(os.path.join(path1, rank_dir),
                        os.path.join(path2, rank_dir),
                        os.path.join(output_dir, rank_dir))

    @classmethod
    def compare(cls, path1: str, path2: str, output_dir: str):
        """Compare all ``step<N>`` directories shared by ``path1`` and ``path2``.

        Raises:
            RuntimeError: if the two paths share no step directory.
        """
        steps = cls._get_matched_dirs(path1, path2, "step")
        if not steps:
            raise RuntimeError("no matched steps for comparison, please dump data in same configuration")
        similarities = cls._calculate_separated_similarities(path1, path2, steps)
        if not os.path.isdir(output_dir):
            create_directory(output_dir)
        cls._save_similarities(similarities, steps, output_dir)

    @classmethod
    def _get_matched_dirs(cls, path1: str, path2: str, dir_prefix):
        """Return the sorted numeric indices N for which ``<dir_prefix><N>`` exists in both paths."""
        check_file_or_directory_path(path1, isdir=True)
        check_file_or_directory_path(path2, isdir=True)
        dirs = []
        for dir_name in os.listdir(path1):
            index = dir_name.replace(dir_prefix, "", 1)
            if not dir_name.startswith(dir_prefix) or not index.isdigit():
                continue

            folder2 = os.path.join(path2, dir_name)
            if not os.path.isdir(folder2):
                continue
            dirs.append(int(index))
        return sorted(dirs)

    @classmethod
    def _save_similarities(cls, similarities: dict, steps: List[int], output_dir: str):
        """Write one plot per parameter and a combined ``similarities.csv``.

        Args:
            similarities: mapping of parameter name to its per-step similarity list.
            steps: step indices, one per similarity value.
            output_dir: directory receiving the csv and the picture sub-directory.

        Raises:
            ValueError: if ``similarities`` is empty.
            RuntimeError: if a similarity list and ``steps`` differ in length,
                or a figure cannot be saved.
        """
        if not similarities:
            raise ValueError("length of similarities is 0")
        # the picture directory does not depend on the parameter: create it once
        picture_dir = os.path.join(output_dir, "similarities_picture")
        if not os.path.isdir(picture_dir):
            create_directory(picture_dir)

        result = [['step'] + [str(step) for step in steps]]
        for key, value in tqdm(similarities.items(), desc="save similarities (by param)"):
            if len(value) != len(steps):
                raise RuntimeError(f"similarities length of {key}:{len(value)} not equal steps:{len(steps)}")
            fig_save_path = os.path.join(picture_dir, f"{key}_similarities.png")
            # validate the target before drawing so a rejected path leaves no open figure
            check_path_before_create(fig_save_path)

            plt.plot(steps, value)
            plt.xlabel('steps')
            plt.ylabel('similarities')
            plt.title(f'{key}_similarities')
            try:
                plt.savefig(fig_save_path)
            except Exception as e:
                raise RuntimeError(f"save plt figure {fig_save_path} failed") from e
            finally:
                # always release the current figure, also when saving failed
                plt.close()

            result.append([key] + value)
        result_csv_path = os.path.join(output_dir, "similarities.csv")
        if os.path.exists(result_csv_path):
            logger.warning(f"{result_csv_path} will be recoverd")
            remove_path(result_csv_path)
        write_csv(result, result_csv_path)

    @classmethod
    def _calculate_separated_similarities(cls, path1, path2, steps):
        """Return ``{param_name: [similarity per step]}`` plus an overall ``summary`` entry.

        A similarity is ``same_count / total_count`` of the two gradient files;
        it is recorded as ``0`` when the file holds no elements.
        """
        similarities = {}
        logger.info(f"{len(steps)} steps will be compared")
        grad_weight_order = cls._get_grad_weight_order(path1, path2)
        for step in tqdm(steps, desc="culculate similarities (by step)"):
            grad_files = cls._get_matched_grad_files(path1, path2, step)
            step_dir = f"step{step}"
            same_count_summary = 0
            total_count_summary = 0
            for grad_name in grad_weight_order:
                grad_file = cls._get_name_matched_grad_file(grad_name, grad_files)
                grad1 = os.path.join(path1, step_dir, grad_file)
                grad2 = os.path.join(path2, step_dir, grad_file)
                same_count, total_count = cls._calculate_similarity(grad1, grad2)
                same_count_summary += same_count
                total_count_summary += total_count
                param_name = grad_file[:grad_file.rfind(".")]
                similarities.setdefault(param_name, []).append(
                    0 if total_count == 0 else same_count / total_count
                )
            similarities.setdefault(GradConst.SUMMARY, []).append(
                0 if total_count_summary == 0 else same_count_summary / total_count_summary
            )
        return similarities

    @classmethod
    def _get_matched_grad_files(cls, path1: str, path2: str, step: int):
        """Return the sorted gradient file names present in ``step<step>`` of both paths.

        Only names of the form ``<name>.<suffix>`` with a suffix listed in
        ``GradConst.GRAD_FILE_SUFFIX`` are considered.
        """
        path1 = os.path.join(path1, f"step{step}")
        path2 = os.path.join(path2, f"step{step}")
        check_file_or_directory_path(path1, isdir=True)
        check_file_or_directory_path(path2, isdir=True)
        grad_files = []
        for grad_file in os.listdir(path1):
            # require a real "." separator: a file called just "npy" has no suffix
            _, dot, suffix = grad_file.rpartition('.')
            if not dot or suffix not in GradConst.GRAD_FILE_SUFFIX:
                continue
            if not os.path.exists(os.path.join(path2, grad_file)):
                continue
            grad_files.append(grad_file)
        return sorted(grad_files)

    @classmethod
    def _calculate_similarity(cls, grad_file1: str, grad_file2: str):
        """Return ``(same_count, total_count)``: equal positions and element count of the two files."""
        npy1, npy2 = cls._load_grad_files(grad_file1, grad_file2)
        same_count = (npy1 == npy2).sum()
        total_count = npy1.size
        return same_count, total_count

    @classmethod
    def _load_grad_files(cls, grad_file1: str, grad_file2: str):
        """Load both files with ``load_npy`` and check they are comparable.

        Raises:
            RuntimeError: if the two arrays differ in shape.
            TypeError: if either array is not of dtype ``bool``.
        """
        grad1 = load_npy(grad_file1)
        grad2 = load_npy(grad_file2)
        if grad1.shape != grad2.shape:
            raise RuntimeError(f"tensor shape is not equal: {grad_file1}, {grad_file2}")
        if grad1.dtype != bool:
            raise TypeError(f"tensor type is not bool: {grad_file1}")
        if grad2.dtype != bool:
            raise TypeError(f"tensor type is not bool: {grad_file2}")
        return grad1, grad2
|
|
174
|
+
|
|
175
|
+
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from msprobe.core.grad_probe.constant import GradConst
|
|
3
|
+
from msprobe.core.common.log import logger
|
|
4
|
+
from msprobe.core.common.utils import write_csv
|
|
5
|
+
|
|
6
|
+
def data_in_list_target(data, lst):
    """Return True when ``lst`` is empty/None (no filter) or ``data`` is in ``lst``."""
    # an empty container is already falsy, so no separate length check is needed
    if not lst:
        return True
    return data in lst
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def check_numeral_list_ascend(lst):
    """Validate that ``lst`` holds only int/float values in ascending order.

    Equal neighbours are allowed. Returns None on success.

    Raises:
        Exception: if an item is not an int/float, or the values are not ascending.
    """
    # Materialize once so tuples and other sequences compare correctly:
    # ``(1, 2) != sorted((1, 2))`` is always True because a tuple never equals a list.
    values = list(lst)
    if any(not isinstance(item, (int, float)) for item in values):
        raise Exception("The input list should only contain numbers")
    if values != sorted(values):
        raise Exception("The input list should be ascending")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def check_param(param_name):
    """Reject parameter names that do not match ``GradConst.PARAM_VALID_PATTERN``.

    Raises:
        RuntimeError: if ``param_name`` contains a character outside the pattern.
    """
    # fullmatch, not match: with re.match the pattern's trailing ``$`` also matches
    # just before a final newline, so a name like "weight\n" would be accepted.
    if not re.fullmatch(GradConst.PARAM_VALID_PATTERN, param_name):
        raise RuntimeError("The parameter name contains special characters.")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def check_str(string, variable_name):
    """Ensure ``string`` is a ``str``.

    ``variable_name`` is only used to name the offending variable in the error.

    Raises:
        ValueError: if ``string`` is not a ``str`` instance.
    """
    if isinstance(string, str):
        return
    raise ValueError(f'The variable: "{variable_name}" is not a string.')
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ListCache(list):
    """List that buffers rows and hands them to ``write_csv`` in batches.

    Rows are appended as on a plain list. Once ``threshold`` rows have
    accumulated, on an explicit ``flush()``, or when the object is deleted,
    the buffered rows are passed to ``write_csv`` and the buffer is cleared.
    """

    # number of buffered rows at which append() triggers a flush
    threshold = 1000

    def __init__(self, *args):
        super().__init__(*args)
        self._output_file = None

    def __del__(self):
        self.flush()

    def flush(self):
        """Write the buffered rows to the output file and empty the buffer.

        Does nothing when the buffer is empty. When no output file has been
        set, a warning is logged and the rows stay buffered instead of being
        passed to ``write_csv`` with an invalid path.
        """
        if not self:
            return
        if not self._output_file:
            logger.warning("dumpfile path is not setted")
            # nothing to write to: keep the rows until set_output_file() is called
            return
        write_csv(self, self._output_file)
        logger.info(f"write {len(self)} items to {self._output_file}.")
        self.clear()

    def append(self, data):
        """Append ``data`` and flush once ``threshold`` rows are buffered."""
        list.append(self, data)
        if len(self) >= ListCache.threshold:
            self.flush()

    def set_output_file(self, output_file):
        """Set the path that ``flush()`` passes to ``write_csv``."""
        self._output_file = output_file
|