mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- msprobe/README.md +46 -16
- msprobe/__init__.py +16 -1
- msprobe/config.json +0 -2
- msprobe/core/advisor/advisor.py +8 -8
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +64 -3
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +54 -9
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +21 -11
- msprobe/core/common/utils.py +153 -167
- msprobe/core/common_config.py +18 -25
- msprobe/core/compare/acc_compare.py +209 -36
- msprobe/core/compare/check.py +102 -17
- msprobe/core/compare/compare_cli.py +21 -1
- msprobe/core/compare/highlight.py +41 -5
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +21 -6
- msprobe/core/compare/utils.py +82 -48
- msprobe/core/data_dump/data_collector.py +31 -32
- msprobe/core/data_dump/data_processor/base.py +45 -22
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
- msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +32 -16
- msprobe/core/grad_probe/constant.py +4 -0
- msprobe/core/grad_probe/grad_compare.py +2 -3
- msprobe/core/grad_probe/utils.py +16 -3
- msprobe/docs/01.installation.md +19 -9
- msprobe/docs/02.config_introduction.md +52 -80
- msprobe/docs/03.config_examples.md +3 -13
- msprobe/docs/04.acl_config_examples.md +11 -9
- msprobe/docs/05.data_dump_PyTorch.md +140 -12
- msprobe/docs/06.data_dump_MindSpore.md +47 -5
- msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/docs/17.grad_probe.md +14 -16
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
- msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
- msprobe/mindspore/cell_processor.py +27 -3
- msprobe/mindspore/common/const.py +2 -0
- msprobe/mindspore/common/utils.py +18 -2
- msprobe/mindspore/compare/distributed_compare.py +9 -22
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +173 -35
- msprobe/mindspore/compare/ms_graph_compare.py +27 -11
- msprobe/mindspore/debugger/debugger_config.py +16 -13
- msprobe/mindspore/debugger/precision_debugger.py +37 -13
- msprobe/mindspore/dump/dump_tool_factory.py +16 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +41 -17
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
- msprobe/mindspore/free_benchmark/common/utils.py +19 -5
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
- msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
- msprobe/mindspore/grad_probe/global_context.py +18 -8
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/service.py +42 -123
- msprobe/pytorch/__init__.py +20 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +20 -5
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +26 -11
- msprobe/pytorch/common/utils.py +40 -35
- msprobe/pytorch/compare/distributed_compare.py +11 -11
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +38 -6
- msprobe/pytorch/debugger/debugger_config.py +52 -39
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/enums.py +28 -0
- msprobe/pytorch/free_benchmark/common/params.py +15 -0
- msprobe/pytorch/free_benchmark/common/utils.py +17 -1
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +10 -11
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +17 -2
- msprobe/pytorch/online_dispatch/compare.py +11 -12
- msprobe/pytorch/online_dispatch/single_compare.py +7 -7
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
- msprobe/pytorch/online_dispatch/utils.py +1 -4
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +9 -10
- msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
- msprobe/pytorch/parse_tool/lib/utils.py +28 -24
- msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
- msprobe/pytorch/pt_config.py +167 -38
- msprobe/pytorch/service.py +97 -32
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
msprobe/pytorch/common/__init__.py CHANGED
@@ -1,2 +1,17 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from .parse_json import parse_json_info_forward_backward
 from .utils import seed_all
msprobe/pytorch/common/log.py CHANGED
@@ -1,9 +1,21 @@
-
-
-
-
-
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from msprobe.core.common.exceptions import DistributedNotInitializedError
+from msprobe.core.common.log import BaseLogger
+from msprobe.pytorch.common.utils import get_rank_if_initialized
 
 
 class PyTorchLogger(BaseLogger):
@@ -18,4 +30,4 @@ class PyTorchLogger(BaseLogger):
         return current_rank
 
 
-logger = PyTorchLogger()
+logger = PyTorchLogger()
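For orientation, the `logger` instantiated at the bottom of this file is the rank-aware logger used by the other PyTorch hunks in this diff (for example `logger.error(...)` in the compare modules and `logger.warning_on_rank_0(...)` in `debugger_config.py`). A minimal, hedged usage sketch; the message strings are made up for illustration:

```python
# Illustrative only: uses the module-level PyTorchLogger instance shown in the diff above.
from msprobe.pytorch.common.log import logger

logger.error("compare failed, check the arguments")   # hypothetical message
logger.warning_on_rank_0("model argument ignored")     # rank-aware variant seen later in this diff
```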
msprobe/pytorch/common/parse_json.py CHANGED
@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 
 from msprobe.core.common.exceptions import ParseJsonException
@@ -5,14 +20,6 @@ from msprobe.core.common.file_utils import FileOpen
 
 
 def parse_json_info_forward_backward(json_path):
-    def parse_data_name_with_pattern(data_name, pattern):
-        name_struct = data_name.split('.')
-        if not name_struct[-1] == pattern:
-            raise ParseJsonException(ParseJsonException.UnexpectedNameStruct,
-                                     f"{data_name} in file {json_path}")
-        api_name = '.'.join(name_struct[:-1])
-        return api_name
-
     with FileOpen(json_path, 'r') as f:
         dump_json = json.load(f)
 
@@ -27,13 +34,21 @@ def parse_json_info_forward_backward(json_path):
         if "Module" in data_name:
             continue
         if "forward" in data_name:
-            api_name = parse_data_name_with_pattern(data_name, "forward")
+            api_name = parse_data_name_with_pattern(data_name, "forward", json_path)
             forward_data.update({api_name: data_item})
         elif "backward" in data_name:
-            api_name = parse_data_name_with_pattern(data_name, "backward")
+            api_name = parse_data_name_with_pattern(data_name, "backward", json_path)
             backward_data.update({api_name: data_item})
         else:
             raise ParseJsonException(ParseJsonException.UnexpectedNameStruct,
-
+                                     f"{data_name} in file {json_path}.")
 
     return forward_data, backward_data, real_data_path
+
+
+def parse_data_name_with_pattern(data_name, pattern, json_path):
+    name_struct = data_name.split('.')
+    if not name_struct[-1] == pattern:
+        raise ParseJsonException(ParseJsonException.UnexpectedNameStruct, f"{data_name} in file {json_path}")
+    api_name = '.'.join(name_struct[:-1])
+    return api_name
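The refactor above lifts `parse_data_name_with_pattern` out of `parse_json_info_forward_backward` and threads `json_path` through so error messages can name the offending file. A small self-contained sketch of the split/join behaviour it implements; the `Functional.conv2d.0.forward` name is a hypothetical example, not taken from this diff:

```python
# Re-implementation of the helper's logic shown above, for illustration only.
def split_api_name(data_name: str, pattern: str) -> str:
    name_struct = data_name.split('.')
    if name_struct[-1] != pattern:
        raise ValueError(f"unexpected name structure: {data_name}")
    return '.'.join(name_struct[:-1])

print(split_api_name("Functional.conv2d.0.forward", "forward"))  # -> Functional.conv2d.0
```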
msprobe/pytorch/common/utils.py CHANGED
@@ -1,8 +1,7 @@
-
-#
-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -13,20 +12,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+
 import io
 import os
 import random
 import stat
+from functools import wraps
+
+import numpy as np
 import torch
 import torch.distributed as dist
-import numpy as np
-from functools import wraps
 from msprobe.core.common.exceptions import DistributedNotInitializedError
-from msprobe.core.common.log import logger
 from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
                                             check_file_or_directory_path, check_path_before_create)
-
+from msprobe.core.common.log import logger
+from msprobe.core.common.utils import check_seed_all
+from packaging import version
 
 try:
     import torch_npu
@@ -35,10 +36,8 @@ except ImportError:
 else:
     is_gpu = False
 
-
 torch_without_guard_version = torch.__version__ >= '2.1'
 
-
 if not is_gpu and not torch_without_guard_version:
     from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard
 
@@ -46,7 +45,6 @@ npu_distributed_api = ['isend', 'irecv']
 
 
 def parameter_adapter(func):
-
     def handle_masked_select(input_tensor, indices):
         masked_select_func = getattr(torch._C._VariableFunctionsClass, "masked_select")
         if input_tensor.dtype == torch.bfloat16:
@@ -80,17 +78,19 @@ def parameter_adapter(func):
         if self.op_name_ == "__eq__" and args[1] is None:
             return False
         return func(self, *args, **kwargs)
+
     return inner
 
 
 def torch_device_guard(func):
     if is_gpu or torch_without_guard_version:
         return func
-    # Parse args/kwargs matched torch.device objects
 
+    # Parse args/kwargs matched torch.device objects
     @torch_npu_device_guard
     def wrapper(*args, **kwargs):
         return func(*args, **kwargs)
+
     return wrapper
 
 
@@ -105,20 +105,28 @@ def get_rank_if_initialized():
 
 
 def seed_all(seed=1234, mode=False):
-
-
-
-
-
-
-    torch.cuda
-
-
-
-    torch.
-
-
-
+    check_seed_all(seed, mode)
+    try:
+        random.seed(seed)
+        os.environ['PYTHONHASHSEED'] = str(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        cuda_version = torch.version.cuda
+        if cuda_version is not None and version.parse(cuda_version) >= version.parse("10.2"):
+            os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+        os.environ['HCCL_DETERMINISTIC'] = str(mode)
+        torch.use_deterministic_algorithms(mode)
+        if is_gpu:
+            torch.cuda.manual_seed_all(seed)
+            torch.cuda.manual_seed(seed)
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.enable = False
+            torch.backends.cudnn.benchmark = False
+        else:
+            torch_npu.npu.manual_seed_all(seed)
+            torch_npu.npu.manual_seed(seed)
+    except Exception as e:
+        logger.error(f"There is an unexpected error while determinating randomness. {e}")
 
 
 class Const:
@@ -191,10 +199,7 @@
     ENV_ENABLE = "1"
     ENV_DISABLE = "0"
 
-    MAX_SEED_VALUE = 2**32 - 1
-
-    INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
-                    "_reduce_scatter_base", "_all_gather_base", "all_to_all_single"]
+    MAX_SEED_VALUE = 2 ** 32 - 1
 
     TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"]
     LEVEL_LIST = ["L0", "L1", "L2", "mix"]
@@ -257,7 +262,7 @@ def print_rank_0(message):
         logger.info(message)
     else:
         logger.info(message)
-
+
 
 def load_pt(pt_path, to_cpu=False):
     pt_path = os.path.realpath(pt_path)
@@ -279,8 +284,8 @@ def save_pt(tensor, filepath):
         torch.save(tensor, filepath)
     except Exception as e:
         logger.error("Save pt file failed, please check according possible error causes: "
-
-
+                     "1. out of disk space or disk error, "
+                     "2. no permission to write files, etc.")
         raise RuntimeError(f"save pt file {filepath} failed") from e
     change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
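The new `seed_all` body above seeds the Python, NumPy and torch RNGs and, with `mode=True`, also requests deterministic behaviour (CUBLAS workspace config on CUDA >= 10.2, `HCCL_DETERMINISTIC`, `torch.use_deterministic_algorithms`). A hedged usage sketch; the import path simply follows this file's location:

```python
# Usage sketch for the seed_all shown above (assumes the mindstudio-probe package is installed).
from msprobe.pytorch.common.utils import seed_all

seed_all()                    # default: seed 1234, determinism switches left off
seed_all(seed=42, mode=True)  # mode=True additionally sets HCCL_DETERMINISTIC and
                              # torch.use_deterministic_algorithms(True)
```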
msprobe/pytorch/compare/distributed_compare.py CHANGED
@@ -1,8 +1,7 @@
-
-#
-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright (c) 2019-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -13,14 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+
 import os
 from msprobe.core.common.utils import CompareException, check_compare_param, \
     check_configuration_param, task_dumppath_get
 from msprobe.core.common.file_utils import create_directory
 from msprobe.core.common.exceptions import FileCheckException
 from msprobe.pytorch.common.log import logger
-from msprobe.core.common.const import Const
 from msprobe.pytorch.compare.pt_compare import PTComparator
 from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
 
@@ -55,12 +53,14 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
         }
         try:
             summary_compare, md5_compare = task_dumppath_get(dump_result_param)
-            check_configuration_param(stack_mode, auto_analyze, fuzzy_match
+            check_configuration_param(stack_mode, auto_analyze, fuzzy_match,
+                                      dump_result_param.get('is_print_compare_log', True))
             create_directory(output_path)
-            check_compare_param(dump_result_param, output_path,
+            check_compare_param(dump_result_param, output_path,
+                                summary_compare=summary_compare, md5_compare=md5_compare)
         except (CompareException, FileCheckException) as error:
             logger.error('Compare failed. Please check the arguments and do it again!')
             raise CompareException(error.code) from error
         pt_comparator = PTComparator()
-        pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}',
-
+        pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}',
+                                   summary_compare=summary_compare, md5_compare=md5_compare, **kwargs)
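Per the hunk header above, `compare_distributed` takes the NPU and bench dump directories plus an output directory, and now forwards extra keyword arguments through to `compare_core`. A hedged call-shape sketch with hypothetical paths; the expected layout of the dump directories themselves is not shown in this diff:

```python
# Hypothetical directory names; only the call signature visible above is exercised.
from msprobe.pytorch.compare.distributed_compare import compare_distributed

compare_distributed("./npu_dump", "./bench_dump", "./compare_output")
```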
msprobe/pytorch/compare/match.py CHANGED
@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 from msprobe.core.common.utils import CompareException
 from msprobe.core.common.file_utils import load_yaml
msprobe/pytorch/compare/pt_compare.py CHANGED
@@ -1,17 +1,48 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os.path
 import torch
 from msprobe.core.common.const import FileCheckConst
 from msprobe.pytorch.common.log import logger
 from msprobe.core.common.exceptions import FileCheckException
 from msprobe.core.compare.acc_compare import Comparator
-from msprobe.core.common.utils import check_configuration_param, task_dumppath_get, check_compare_param,
-
+from msprobe.core.common.utils import check_configuration_param, task_dumppath_get, check_compare_param, \
+    CompareException
+from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml
 from msprobe.pytorch.common.utils import load_pt
 
 
 class PTComparator (Comparator):
-    def __init__(self):
+    def __init__(self, data_mapping=None):
         self.frame_name = PTComparator.__name__
+        self.data_mapping = data_mapping
+        if isinstance(self.data_mapping, str) or self.data_mapping is None:
+            self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
+        elif isinstance(self.data_mapping, dict):
+            self.data_mapping_dict = self.data_mapping
+        else:
+            raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
+                            f"{type(self.data_mapping)}")
+
+    def load_mapping_file(self, mapping_file):
+        if isinstance(mapping_file, str):
+            mapping_dict = load_yaml(mapping_file)
+        else:
+            mapping_dict = {}
+        return mapping_dict
 
     def read_npy_data(self, dir_path, file_name):
         data_path = os.path.join(dir_path, file_name)
@@ -35,16 +66,17 @@ class PTComparator (Comparator):
         return data_value
 
 
-def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False):
+def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False, **kwargs):
     try:
         summary_compare, md5_compare = task_dumppath_get(input_param)
-        check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
+        check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
         create_directory(output_path)
         check_compare_param(input_param, output_path, summary_compare, md5_compare)
+        data_mapping = kwargs.get('data_mapping', None)
     except (CompareException, FileCheckException) as error:
         logger.error('Compare failed. Please check the arguments and do it again!')
         raise CompareException(error.code) from error
-    pt_comparator = PTComparator()
+    pt_comparator = PTComparator(data_mapping)
    pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
                               auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
                               md5_compare=md5_compare)
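The new `data_mapping` plumbing above accepts a YAML file path, an already-built dict, or `None` (which yields an empty mapping). A hedged sketch of the two explicit forms; the mapping keys and the YAML file name are hypothetical placeholders, since the mapping schema itself is not part of this hunk:

```python
# Constructor behaviour shown in the diff above; the mapping contents are made up.
from msprobe.pytorch.compare.pt_compare import PTComparator, compare

PTComparator(data_mapping={"NPU.op_name": "Bench.op_name"})   # a dict is used as-is
PTComparator(data_mapping="data_mapping.yaml")                # a str is loaded via load_yaml

# compare() now forwards the same option through **kwargs, e.g.:
# compare(input_param, output_path, data_mapping="data_mapping.yaml")
```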
msprobe/pytorch/debugger/debugger_config.py CHANGED
@@ -1,6 +1,23 @@
-
-
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
 from msprobe.core.common.const import Const
+from msprobe.core.common.exceptions import MsprobeException
+from msprobe.pytorch.common.log import logger
 
 
 class DebuggerConfig:
@@ -10,8 +27,6 @@ class DebuggerConfig:
         self.rank = common_config.rank if common_config.rank else []
         self.step = common_config.step if common_config.step else []
         self.level = level or common_config.level or "L1"
-        self.seed = common_config.seed if common_config.seed else 1234
-        self.is_deterministic = common_config.is_deterministic
         self.enable_dataloader = common_config.enable_dataloader
         self.scope = task_config.scope if task_config.scope else []
         self.list = task_config.list if task_config.list else []
@@ -25,15 +40,15 @@ class DebuggerConfig:
         self.framework = Const.PT_FRAMEWORK
 
         if self.task == Const.FREE_BENCHMARK:
-            self.fuzz_device = task_config.fuzz_device
-            self.handler_type = task_config.handler_type
-            self.pert_mode = task_config.pert_mode
-            self.fuzz_level = task_config.fuzz_level
-            self.fuzz_stage = task_config.fuzz_stage
+            self.fuzz_device = task_config.fuzz_device
+            self.handler_type = task_config.handler_type
+            self.pert_mode = task_config.pert_mode
+            self.fuzz_level = task_config.fuzz_level
+            self.fuzz_stage = task_config.fuzz_stage
             self.preheat_config = {
-                "if_preheat": task_config.if_preheat
-                "preheat_step": task_config.preheat_step
-                "max_sample": task_config.max_sample
+                "if_preheat": task_config.if_preheat,
+                "preheat_step": task_config.preheat_step,
+                "max_sample": task_config.max_sample
             }
 
         self.online_run_ut = False
@@ -46,8 +61,7 @@ class DebuggerConfig:
             self.port = task_config.port if task_config.port else -1
 
         self.check()
-
-        self.step.sort()
+
         if self.level == "L2":
             if not self.scope or not isinstance(self.scope, list) or len(self.scope) != 1:
                 raise ValueError("scope must be configured as a list with one api name")
@@ -58,38 +72,37 @@ class DebuggerConfig:
             for index, scope_spec in enumerate(self.scope):
                 self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD)
                 self.backward_input[self.scope[index]] = self.backward_input_list[index]
-        seed_all(self.seed, self.is_deterministic)
 
     def check_kwargs(self):
         if self.task and self.task not in Const.TASK_LIST:
-            raise
+            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+                                   f"The task <{self.task}> is not in the {Const.TASK_LIST}.")
         if self.level and self.level not in Const.LEVEL_LIST:
-            raise
+            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+                                   f"The level <{self.level}> is not in the {Const.LEVEL_LIST}.")
         if not self.dump_path:
-            raise
+            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+                                   f"The dump_path not found.")
 
     def check(self):
         self.check_kwargs()
-        self._check_rank()
-        self._check_step()
         return True
 
-    def check_model(self,
-    if self.level in ["L0", "mix"]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            raise ValueError(f"step element {s} must be an integer and greater than or equal to 0.")
+    def check_model(self, instance, start_model):
+        if self.level not in ["L0", "mix"]:
+            if instance.model is not None or start_model is not None:
+                logger.warning_on_rank_0(
+                    f"The current level is not L0 or mix level, so the model parameters will not be used.")
+            return
+        if start_model is None:
+            if instance.model is None:
+                logger.error_on_rank_0(
+                    f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' argument.")
+                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
+            return
+        if isinstance(start_model, torch.nn.Module):
+            instance.model = start_model
+        else:
+            logger.error_on_rank_0(f"The 'model' parameter of start must be a torch.nn.Module type.")
+            raise MsprobeException(
+                MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module")