mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
msprobe/pytorch/common/utils.py
CHANGED
|
@@ -14,13 +14,21 @@
|
|
|
14
14
|
# See the License for the specific language governing permissions and
|
|
15
15
|
# limitations under the License.
|
|
16
16
|
"""
|
|
17
|
+
import logging
|
|
17
18
|
import os
|
|
18
19
|
import random
|
|
19
20
|
import stat
|
|
21
|
+
import csv
|
|
22
|
+
import json
|
|
20
23
|
import torch
|
|
24
|
+
import torch.distributed as dist
|
|
21
25
|
import numpy as np
|
|
22
26
|
from functools import wraps
|
|
23
27
|
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
28
|
+
from msprobe.core.common.log import logger as common_logger
|
|
29
|
+
from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create, CompareException
|
|
30
|
+
from msprobe.core.common.file_check import FileCheckConst, change_mode, FileOpen
|
|
31
|
+
|
|
24
32
|
|
|
25
33
|
try:
|
|
26
34
|
import torch_npu
|
|
@@ -30,13 +38,8 @@ else:
|
|
|
30
38
|
is_gpu = False
|
|
31
39
|
|
|
32
40
|
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
if torch.__version__.startswith(version):
|
|
36
|
-
torch_without_guard_version = True
|
|
37
|
-
break
|
|
38
|
-
else:
|
|
39
|
-
torch_without_guard_version = False
|
|
41
|
+
torch_without_guard_version = torch.__version__ >= '2.1'
|
|
42
|
+
|
|
40
43
|
|
|
41
44
|
if not is_gpu and not torch_without_guard_version:
|
|
42
45
|
from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard
|
|
@@ -222,3 +225,76 @@ class Const:
|
|
|
222
225
|
CONVERT_API = {
|
|
223
226
|
"int32_to_int64": ["cross_entropy"]
|
|
224
227
|
}
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def get_tensor_rank(in_feat, out_feat):
|
|
231
|
+
if dist.is_initialized():
|
|
232
|
+
return dist.get_rank()
|
|
233
|
+
|
|
234
|
+
def get_tensor_rank_single(x):
|
|
235
|
+
if isinstance(x, (list, tuple)):
|
|
236
|
+
if len(x) > 0:
|
|
237
|
+
return get_tensor_rank_single(x[0])
|
|
238
|
+
elif isinstance(x, torch.Tensor):
|
|
239
|
+
device = x.device
|
|
240
|
+
if device.type != 'cpu':
|
|
241
|
+
return device.index
|
|
242
|
+
return None
|
|
243
|
+
|
|
244
|
+
in_rank = get_tensor_rank_single(in_feat)
|
|
245
|
+
out_rank = get_tensor_rank_single(out_feat)
|
|
246
|
+
tensor_rank = in_rank if in_rank else out_rank
|
|
247
|
+
return tensor_rank
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def get_rank_id():
|
|
251
|
+
if torch.distributed.is_initialized():
|
|
252
|
+
return torch.distributed.get_rank()
|
|
253
|
+
return 0
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def print_rank_0(message):
|
|
257
|
+
if dist.is_initialized():
|
|
258
|
+
if dist.get_rank() == 0:
|
|
259
|
+
logger.info(message)
|
|
260
|
+
else:
|
|
261
|
+
logger.info(message)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def load_pt(pt_path, to_cpu=False):
|
|
265
|
+
pt_path = os.path.realpath(pt_path)
|
|
266
|
+
check_file_or_directory_path(pt_path)
|
|
267
|
+
try:
|
|
268
|
+
if to_cpu:
|
|
269
|
+
pt = torch.load(pt_path, map_location=torch.device("cpu"))
|
|
270
|
+
else:
|
|
271
|
+
pt = torch.load(pt_path)
|
|
272
|
+
except Exception as e:
|
|
273
|
+
raise RuntimeError(f"load pt file {pt_path} failed") from e
|
|
274
|
+
return pt
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def save_pt(tensor, filepath):
|
|
278
|
+
filepath = os.path.realpath(filepath)
|
|
279
|
+
check_path_before_create(filepath)
|
|
280
|
+
try:
|
|
281
|
+
torch.save(tensor, filepath)
|
|
282
|
+
except Exception as e:
|
|
283
|
+
common_logger.error("Save pt file failed, please check according possible error causes: "
|
|
284
|
+
"1. out of disk space or disk error, "
|
|
285
|
+
"2. no permission to write files, etc.")
|
|
286
|
+
raise RuntimeError(f"save pt file {filepath} failed") from e
|
|
287
|
+
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def _create_logger(level=logging.INFO):
|
|
291
|
+
logger_ = logging.getLogger()
|
|
292
|
+
logger_.setLevel(level)
|
|
293
|
+
ch = logging.StreamHandler()
|
|
294
|
+
ch.setLevel(level)
|
|
295
|
+
logger_.addHandler(ch)
|
|
296
|
+
return logger_
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
log_level = logging.DEBUG if os.environ.get("API_ACCURACY_CHECK_LOG_LEVEL") == "1" else logging.INFO
|
|
300
|
+
logger = _create_logger(log_level)
|
|
@@ -15,62 +15,17 @@
|
|
|
15
15
|
# limitations under the License.
|
|
16
16
|
"""
|
|
17
17
|
import os
|
|
18
|
-
import sys
|
|
19
|
-
import re
|
|
20
18
|
from msprobe.core.common.utils import CompareException, check_compare_param, \
|
|
21
|
-
check_configuration_param, task_dumppath_get
|
|
22
|
-
from msprobe.pytorch.compare.acc_compare import compare_core
|
|
19
|
+
check_configuration_param, task_dumppath_get
|
|
23
20
|
from msprobe.core.common.file_check import create_directory
|
|
24
|
-
from msprobe.
|
|
21
|
+
from msprobe.core.common.exceptions import FileCheckException
|
|
22
|
+
from msprobe.core.common.log import logger
|
|
23
|
+
from msprobe.core.common.const import Const
|
|
24
|
+
from msprobe.pytorch.compare.pt_compare import PTComparator
|
|
25
|
+
from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
|
|
25
26
|
|
|
26
27
|
|
|
27
28
|
def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
|
|
28
|
-
def check_and_return_dir_contents(dump_dir, prefix):
|
|
29
|
-
"""
|
|
30
|
-
check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
|
|
31
|
-
pattern: ^{prefix}(?:0|[0-9][1-9]*)?$
|
|
32
|
-
|
|
33
|
-
Args:
|
|
34
|
-
dump_dir (str): dump dir
|
|
35
|
-
prefix (str): prefix for the patterns, prefix should be less than 20 characters and alphanumeric/-/_ only
|
|
36
|
-
|
|
37
|
-
Returns:
|
|
38
|
-
content [list]: dir contents
|
|
39
|
-
Raises:
|
|
40
|
-
CompareException: invalid path
|
|
41
|
-
ValueError: prefix not match the patterns
|
|
42
|
-
|
|
43
|
-
"""
|
|
44
|
-
check_regex_prefix_format_valid(prefix)
|
|
45
|
-
check_file_or_directory_path(dump_dir, True)
|
|
46
|
-
contents = os.listdir(dump_dir)
|
|
47
|
-
pattern = re.compile(rf'^{prefix}(?:0|[0-9][1-9]*)?$')
|
|
48
|
-
for name in contents:
|
|
49
|
-
if not pattern.match(name):
|
|
50
|
-
logger.error(
|
|
51
|
-
f"dump_dir contains '{name}'. Expected '{prefix}'. This name is not in the format of dump "
|
|
52
|
-
f"output. Please check and delete irrelevant files in {dump_dir} and try again."
|
|
53
|
-
)
|
|
54
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
55
|
-
return contents
|
|
56
|
-
|
|
57
|
-
def extract_json(dirname, stack_json=False):
|
|
58
|
-
json_path = ''
|
|
59
|
-
for fname in os.listdir(dirname):
|
|
60
|
-
full_path = os.path.join(dirname, fname)
|
|
61
|
-
if full_path.endswith('.json'):
|
|
62
|
-
json_path = full_path
|
|
63
|
-
if not stack_json and 'stack' not in json_path:
|
|
64
|
-
break
|
|
65
|
-
if stack_json and 'stack' in json_path:
|
|
66
|
-
break
|
|
67
|
-
|
|
68
|
-
# Provide robustness on invalid directory inputs
|
|
69
|
-
if not json_path:
|
|
70
|
-
logger.error(f'No file is found in dump dir {dirname}. ')
|
|
71
|
-
raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
|
|
72
|
-
return json_path
|
|
73
|
-
|
|
74
29
|
if kwargs.get('suffix'):
|
|
75
30
|
logger.error("Argument 'suffix' is not supported for compare_distributed.")
|
|
76
31
|
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
@@ -86,26 +41,26 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
|
|
|
86
41
|
'or use compare() api and manually match the ranks.')
|
|
87
42
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
88
43
|
for nr, br in zip(npu_ranks, bench_ranks):
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
stack_json_path = extract_json(s_dir, stack_json=True)
|
|
44
|
+
npu_data_dir = os.path.join(npu_dump_dir, nr)
|
|
45
|
+
bench_data_dir = os.path.join(bench_dump_dir, br)
|
|
46
|
+
npu_path = extract_json(npu_data_dir, stack_json=False)
|
|
47
|
+
bench_path = extract_json(bench_data_dir, stack_json=False)
|
|
48
|
+
stack_path = extract_json(npu_data_dir, stack_json=True)
|
|
95
49
|
|
|
96
50
|
dump_result_param = {
|
|
97
|
-
'npu_json_path':
|
|
98
|
-
'bench_json_path':
|
|
99
|
-
'stack_json_path':
|
|
51
|
+
'npu_json_path': npu_path,
|
|
52
|
+
'bench_json_path': bench_path,
|
|
53
|
+
'stack_json_path': stack_path,
|
|
100
54
|
'is_print_compare_log': True
|
|
101
55
|
}
|
|
102
56
|
try:
|
|
103
57
|
summary_compare, md5_compare = task_dumppath_get(dump_result_param)
|
|
104
58
|
check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
|
|
105
59
|
create_directory(output_path)
|
|
106
|
-
check_compare_param(dump_result_param, output_path,
|
|
107
|
-
except CompareException as error:
|
|
60
|
+
check_compare_param(dump_result_param, output_path, summary_compare=summary_compare, md5_compare=md5_compare)
|
|
61
|
+
except (CompareException, FileCheckException) as error:
|
|
108
62
|
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
109
|
-
|
|
110
|
-
|
|
63
|
+
raise CompareException(error.code) from error
|
|
64
|
+
pt_comparator = PTComparator()
|
|
65
|
+
pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare,
|
|
111
66
|
md5_compare=md5_compare, **kwargs)
|
msprobe/pytorch/compare/match.py
CHANGED
|
@@ -1,16 +1,13 @@
|
|
|
1
1
|
import os
|
|
2
|
-
import
|
|
3
|
-
from msprobe.core.common.file_check import FileOpen
|
|
4
|
-
from msprobe.core.common.utils import CompareException
|
|
2
|
+
from msprobe.core.common.utils import CompareException, load_yaml
|
|
5
3
|
|
|
6
4
|
|
|
7
5
|
class AtenIrMapping():
|
|
8
6
|
def __init__(self):
|
|
9
7
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
10
8
|
yaml_path = os.path.join(cur_path, "mapping.yaml")
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
9
|
+
self.aten_mapping = load_yaml(yaml_path)
|
|
10
|
+
|
|
14
11
|
def match(self, op1, op2):
|
|
15
12
|
if "Aten" in op1 and "Aten" not in op2:
|
|
16
13
|
return self.match_op(op1, op2)
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
import os.path
|
|
2
|
+
import torch
|
|
3
|
+
from msprobe.core.common.const import FileCheckConst, Const
|
|
4
|
+
from msprobe.core.common.log import logger
|
|
5
|
+
from msprobe.core.common.exceptions import FileCheckException
|
|
6
|
+
from msprobe.core.compare.acc_compare import Comparator
|
|
7
|
+
from msprobe.core.common.utils import create_directory, check_configuration_param, task_dumppath_get, \
|
|
8
|
+
check_compare_param, FileChecker
|
|
9
|
+
from msprobe.core.common.utils import CompareException
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PTComparator (Comparator):
|
|
13
|
+
def __init__(self):
|
|
14
|
+
self.frame_name = PTComparator.__name__
|
|
15
|
+
|
|
16
|
+
def read_npy_data(self, dir_path, file_name):
|
|
17
|
+
data_path = os.path.join(dir_path, file_name)
|
|
18
|
+
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
|
|
19
|
+
FileCheckConst.PT_SUFFIX, False)
|
|
20
|
+
data_path = path_checker.common_check()
|
|
21
|
+
data_value = torch.load(data_path, map_location=torch.device('cpu')).detach() # detach for less memory
|
|
22
|
+
if data_value.dtype == torch.bfloat16:
|
|
23
|
+
data_value = data_value.to(torch.float32)
|
|
24
|
+
data_value = data_value.numpy()
|
|
25
|
+
return data_value
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False):
|
|
29
|
+
try:
|
|
30
|
+
summary_compare, md5_compare = task_dumppath_get(input_param)
|
|
31
|
+
check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
|
|
32
|
+
create_directory(output_path)
|
|
33
|
+
check_compare_param(input_param, output_path, summary_compare, md5_compare)
|
|
34
|
+
except (CompareException, FileCheckException) as error:
|
|
35
|
+
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
36
|
+
raise CompareException(error.code) from error
|
|
37
|
+
pt_comparator = PTComparator()
|
|
38
|
+
pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
|
|
39
|
+
auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
|
|
40
|
+
md5_compare=md5_compare)
|
|
@@ -21,7 +21,7 @@ class DebuggerConfig:
|
|
|
21
21
|
self.acl_config = common_config.acl_config if common_config.acl_config else ""
|
|
22
22
|
self.is_forward_acl_dump = True
|
|
23
23
|
self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
|
|
24
|
-
self.
|
|
24
|
+
self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
|
|
25
25
|
self.framework = Const.PT_FRAMEWORK
|
|
26
26
|
|
|
27
27
|
if self.task == Const.FREE_BENCHMARK:
|
|
@@ -35,7 +35,16 @@ class DebuggerConfig:
|
|
|
35
35
|
"preheat_step": task_config.preheat_step if task_config.preheat_step else 15,
|
|
36
36
|
"max_sample": task_config.max_sample if task_config.max_sample else 20,
|
|
37
37
|
}
|
|
38
|
-
|
|
38
|
+
|
|
39
|
+
self.online_run_ut = False
|
|
40
|
+
if self.task == Const.TENSOR:
|
|
41
|
+
# dump api tensor and collaborate with online run_ut
|
|
42
|
+
self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
|
|
43
|
+
self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
|
|
44
|
+
self.tls_path = task_config.tls_path if task_config.tls_path else ""
|
|
45
|
+
self.host = task_config.host if task_config.host else ""
|
|
46
|
+
self.port = task_config.port if task_config.port else -1
|
|
47
|
+
|
|
39
48
|
self.check()
|
|
40
49
|
if self.step:
|
|
41
50
|
self.step.sort()
|
|
@@ -4,11 +4,14 @@ from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
|
|
|
4
4
|
from msprobe.pytorch.service import Service
|
|
5
5
|
from msprobe.pytorch.common.log import logger
|
|
6
6
|
from msprobe.pytorch.pt_config import parse_json_config
|
|
7
|
-
from msprobe.core.common.exceptions import
|
|
7
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
8
|
+
from msprobe.core.common.const import Const
|
|
9
|
+
from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
|
|
8
10
|
|
|
9
11
|
|
|
10
12
|
class PrecisionDebugger:
|
|
11
13
|
_instance = None
|
|
14
|
+
tasks_not_need_debugger = [Const.GRAD_PROBE]
|
|
12
15
|
|
|
13
16
|
def __new__(cls, *args, **kwargs):
|
|
14
17
|
if cls._instance is None:
|
|
@@ -27,9 +30,14 @@ class PrecisionDebugger:
|
|
|
27
30
|
step=None,
|
|
28
31
|
):
|
|
29
32
|
if not hasattr(self, "initialized"):
|
|
33
|
+
self.api_origin = False
|
|
30
34
|
self.initialized = True
|
|
31
35
|
self.model = self.check_model_valid(model)
|
|
32
36
|
common_config, task_config = parse_json_config(config_path, task)
|
|
37
|
+
self.task = common_config.task
|
|
38
|
+
if self.task == Const.GRAD_PROBE:
|
|
39
|
+
self.gm = GradientMonitor(common_config, task_config)
|
|
40
|
+
return
|
|
33
41
|
if step:
|
|
34
42
|
common_config.step = step
|
|
35
43
|
self.config = DebuggerConfig(
|
|
@@ -50,23 +58,35 @@ class PrecisionDebugger:
|
|
|
50
58
|
def check_model_valid(model):
|
|
51
59
|
if not model or isinstance(model, torch.nn.Module):
|
|
52
60
|
return model
|
|
53
|
-
raise
|
|
54
|
-
|
|
61
|
+
raise MsprobeException(
|
|
62
|
+
MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是torch.nn.Module类型。"
|
|
55
63
|
)
|
|
56
64
|
|
|
57
65
|
@classmethod
|
|
58
66
|
def start(cls):
|
|
59
67
|
instance = cls._instance
|
|
68
|
+
if instance.task in PrecisionDebugger.tasks_not_need_debugger:
|
|
69
|
+
return
|
|
60
70
|
if not instance:
|
|
61
71
|
raise Exception("No instance of PrecisionDebugger found.")
|
|
62
72
|
if instance.enable_dataloader:
|
|
63
73
|
logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
|
|
64
74
|
else:
|
|
65
|
-
instance.service.start(instance.model)
|
|
75
|
+
instance.service.start(instance.model, instance.api_origin)
|
|
76
|
+
instance.api_origin = False
|
|
77
|
+
|
|
78
|
+
# 指定代码段dump前反向结束符,之后的计算过程数据将被忽略,无法被dump
|
|
79
|
+
@classmethod
|
|
80
|
+
def forward_backward_dump_end(cls):
|
|
81
|
+
instance = cls._instance
|
|
82
|
+
instance.service.forward_backward_dump_end()
|
|
83
|
+
instance.api_origin = True
|
|
66
84
|
|
|
67
85
|
@classmethod
|
|
68
86
|
def stop(cls):
|
|
69
87
|
instance = cls._instance
|
|
88
|
+
if instance.task in PrecisionDebugger.tasks_not_need_debugger:
|
|
89
|
+
return
|
|
70
90
|
if not instance:
|
|
71
91
|
raise Exception("PrecisionDebugger instance is not created.")
|
|
72
92
|
if instance.enable_dataloader:
|
|
@@ -76,10 +96,20 @@ class PrecisionDebugger:
|
|
|
76
96
|
|
|
77
97
|
@classmethod
|
|
78
98
|
def step(cls):
|
|
99
|
+
if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger:
|
|
100
|
+
return
|
|
79
101
|
if not cls._instance:
|
|
80
102
|
raise Exception("PrecisionDebugger instance is not created.")
|
|
81
103
|
cls._instance.service.step()
|
|
82
104
|
|
|
105
|
+
@classmethod
|
|
106
|
+
def monitor(cls, model):
|
|
107
|
+
if not cls._instance:
|
|
108
|
+
raise Exception("PrecisionDebugger instance is not created.")
|
|
109
|
+
if cls._instance.task != Const.GRAD_PROBE:
|
|
110
|
+
return
|
|
111
|
+
cls._instance.gm.monitor(model)
|
|
112
|
+
|
|
83
113
|
|
|
84
114
|
def iter_tracer(func):
|
|
85
115
|
def func_wrapper(*args, **kwargs):
|
|
@@ -8,7 +8,7 @@
|
|
|
8
8
|
|
|
9
9
|
**真实数据模式**:精度预检工具支持随机生成模式和真实数据模式,即在预检dump时可以选择由工具构造随机数进行输入获得dump数据或选择获取真实输入数据进行预检dump操作;随机生成模式执行效率高,可以快速获得结果,但数据精度低,只能大致判断精度问题;真实数据模式执行效率略低于随机生成模式,但是数据精度高,可以准确判断精度问题。
|
|
10
10
|
|
|
11
|
-
**工具支持PyTorch版本**:2.0/2.1/2.2。
|
|
11
|
+
**工具支持PyTorch版本**:1.11/2.0/2.1/2.2。
|
|
12
12
|
|
|
13
13
|
**工具特性**
|
|
14
14
|
|
|
@@ -21,7 +21,7 @@
|
|
|
21
21
|
精度预检操作流程如下:
|
|
22
22
|
|
|
23
23
|
1. 在NPU和GPU环境下分别安装msprobe工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。
|
|
24
|
-
2. 在NPU训练脚本内添加msprobe工具dump接口PrecisionDebugger
|
|
24
|
+
2. 在NPU训练脚本内添加msprobe工具dump接口PrecisionDebugger,采集待预检数据。详见《[精度数据采集](./dump.md)》,注意需要配置level="L1"。
|
|
25
25
|
3. 将NPU环境下dump的预检数据拷贝至GPU环境。
|
|
26
26
|
4. 在NPU和GPU环境下分别执行run_ut,生成结果用于最终api_precision_compare操作的输入。详见“**run_ut预检操作**”。
|
|
27
27
|
5. 将NPU和GPU执行run_ut生成的`accuracy_checking_details_{timestamp}.csv`结果文件拷贝至同一环境下。
|
|
@@ -51,10 +51,12 @@ run_ut预检操作包括如下场景:
|
|
|
51
51
|
| -api_info或--api_info_file | 指定API信息文件dump.json。 | 是 |
|
|
52
52
|
| -save_error_data | 保存精度未达标的API输入输出数据。 | 否 |
|
|
53
53
|
| -o或--out_path | 指定run_ut执行结果存盘路径,默认“./”(相对于run_ut的路径)。 | 否 |
|
|
54
|
+
| | | |
|
|
54
55
|
| -j或--jit_compile | 开启jit编译。 | 否 |
|
|
55
56
|
| -d或--device | 指定Device ID,选择UT代码运行所在的卡,默认值为0。 | 否 |
|
|
56
57
|
| -csv_path或--result_csv_path | 指定本次运行中断时生成的`accuracy_checking_result_{timestamp}.csv`文件路径,执行run_ut中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的`accuracy_checking_result_{timestamp}.csv`文件。详见“**断点续检**”。 | run_ut操作中断后继续执行场景下必选 |
|
|
57
58
|
| -f或--filter_api | 过滤模型中除最大值和最小值以外其他参数和结构相同的API。适用于模型较大且重复API较多的场景。 | 否 |
|
|
59
|
+
| -config或--config_path | 指定预检操作过程中的额外配置(包括黑名单、白名单等)的[config.json](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe/config)文件,默认未配置。config.json文件的配置可参考《[配置文件说明](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/config/README.md#pytorch场景task配置为run_ut)》。 | 否 |
|
|
58
60
|
|
|
59
61
|
run_ut执行结果包括`accuracy_checking_result_{timestamp}.csv`和`accuracy_checking_details_{timestamp}.csv`两个文件。`accuracy_checking_result_{timestamp}.csv`是API粒度的,标明每个API是否通过测试。建议用户先查看`accuracy_checking_result_{timestamp}.csv`文件,对于其中没有通过测试的或者特定感兴趣的API,根据其API name字段在`accuracy_checking_details_{timestamp}.csv`中查询其各个输出的达标情况以及比较指标。详细介绍请参见“**预检结果**”。
|
|
60
62
|
|
|
@@ -64,7 +66,7 @@ run_ut预检操作包括如下场景:
|
|
|
64
66
|
msprobe -f pytorch run_ut -api_info ./dump.json -save_error_data
|
|
65
67
|
```
|
|
66
68
|
|
|
67
|
-
数据默认会存盘到'./ut_error_data{timestamp}'路径下(相对于启动run_ut
|
|
69
|
+
数据默认会存盘到'./ut_error_data{timestamp}'路径下(相对于启动run_ut的路径),有需要的话,用户可以通过error_data_path参数来配置保存路径,error_data_path参数在[config.json](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe/config)文件或config.yaml文件配置,config.json文件需要在run_ut操作时通过-config参数指定,config.yaml文件详见“**config.yaml文件说明**”。
|
|
68
70
|
|
|
69
71
|
#### 使用multi_run_ut.py执行多线程预检
|
|
70
72
|
|
|
@@ -99,23 +101,65 @@ msprobe -f pytorch multi_run_ut -api_info ./dump.json -n 32 -d 0 1 2 3
|
|
|
99
101
|
msprobe -f pytorch run_ut -api_info ./dump.json -csv_path /home/xxx/ut/accuracy_checking_result_{timestamp}.csv
|
|
100
102
|
```
|
|
101
103
|
|
|
102
|
-
#### API
|
|
104
|
+
#### API预检黑名单和白名单
|
|
103
105
|
|
|
104
|
-
run_ut过程支持API
|
|
106
|
+
run_ut过程支持API预检黑名单和白名单,通过如下文件配置black_list(黑名单)或white_list(白名单)参数来指定不需要或需要预检的API名称:
|
|
105
107
|
|
|
106
|
-
|
|
108
|
+
- 配置[config.json](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe/config)文件,config.json文件需要在run_ut操作时通过-config参数指定。
|
|
109
|
+
- 配置config.yaml文件,详见“**config.yaml文件说明**”。
|
|
110
|
+
|
|
111
|
+
config.json文件的优先级高于config.yaml文件,即执行config.json文件时,config.yaml文件的配置不生效。
|
|
107
112
|
|
|
108
113
|
### config.yaml文件说明
|
|
109
114
|
|
|
110
|
-
config.yaml文件可以通过配置参数来控制dump和run_ut
|
|
115
|
+
config.yaml文件可以通过配置参数来控制dump和run_ut操作的白名单、黑名单等功能。操作步骤如下:
|
|
116
|
+
|
|
117
|
+
1. 查找msprobe工具安装路径。
|
|
118
|
+
|
|
119
|
+
```bash
|
|
120
|
+
pip show mindstudio-probe
|
|
121
|
+
```
|
|
122
|
+
|
|
123
|
+
输出结果如下示例:
|
|
124
|
+
|
|
125
|
+
```bash
|
|
126
|
+
Name: mindstudio-probe
|
|
127
|
+
Version: 1.0
|
|
128
|
+
Summary: This is a pytorch precision comparison tools
|
|
129
|
+
Home-page:
|
|
130
|
+
Author:
|
|
131
|
+
Author-email:
|
|
132
|
+
License:
|
|
133
|
+
Location: /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages
|
|
134
|
+
Requires: numpy, openpyxl, pandas, pyyaml, rich, tqdm, wheel
|
|
135
|
+
Required-by:
|
|
136
|
+
```
|
|
137
|
+
|
|
138
|
+
Location字段为msprobe工具的安装路径,那么config.yaml文件位置为/home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages/msprobe/pytorch/api_accuracy_checker/config.yaml
|
|
139
|
+
|
|
140
|
+
2. 进入config.yaml文件
|
|
141
|
+
|
|
142
|
+
```bash
|
|
143
|
+
vi /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages/msprobe/pytorch/api_accuracy_checker/config.yaml
|
|
144
|
+
```
|
|
145
|
+
|
|
146
|
+
3. 修改config.yaml文件参数。
|
|
147
|
+
|
|
148
|
+
```yaml
|
|
149
|
+
white_list: []
|
|
150
|
+
black_list: []
|
|
151
|
+
error_data_path: './'
|
|
152
|
+
precision: 14
|
|
153
|
+
```
|
|
111
154
|
|
|
112
|
-
|
|
155
|
+
| 参数名称 | 说明 | 是否必选 |
|
|
156
|
+
| --------------- | ------------------------------------------------------------ | -------- |
|
|
157
|
+
| white_list | API dump白名单,仅对指定的API进行dump。参数示例:white_list=["conv1d", "conv2d"]。默认未配置白名单,即dump全量API数据。 | 否 |
|
|
158
|
+
| black_list | API dump黑名单,被指定的API不进行dump。参数示例:black_list=["conv1d", "conv2d"]。默认未配置黑名单,即dump全量API数据。 | 否 |
|
|
159
|
+
| error_data_path | 配置保存精度未达标的API输入输出数据路径。参数示例"error_data_path": "./"。默认为当前路径。 | 否 |
|
|
160
|
+
| precision | 浮点数表示位数,默认取小数点后14位。 | 否 |
|
|
113
161
|
|
|
114
|
-
|
|
115
|
-
| --------------- | ------------------------------------------------------------ | -------- |
|
|
116
|
-
| white_list | API dump白名单,指定dump具体API数据,也可以直接配置预检的API白名单,详细请参见“**API预检白名单**”。参数示例:white_list=["conv1d", "conv2d"]。默认未配置白名单,即dump全量API数据。 | 否 |
|
|
117
|
-
| error_data_path | 配置保存精度未达标的API输入输出数据路径。 | 否 |
|
|
118
|
-
| precision | 浮点数表示位数,默认取小数点后14位。 | 否 |
|
|
162
|
+
说明:white_list和black_list同时配置时,二者配置的API名单若无交集,则白名单生效,若API名单存在交集,则白名单排除的部分以及交集的API不进行dump。
|
|
119
163
|
|
|
120
164
|
## 预检结果
|
|
121
165
|
|