mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +131 -237
- msprobe/__init__.py +16 -1
- msprobe/{config/config.json → config.json} +47 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +58 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +402 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +523 -283
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +86 -69
- msprobe/core/common/utils.py +371 -616
- msprobe/core/common_config.py +78 -71
- msprobe/core/compare/acc_compare.py +472 -298
- msprobe/core/compare/check.py +180 -95
- msprobe/core/compare/compare_cli.py +69 -49
- msprobe/core/compare/highlight.py +259 -222
- msprobe/core/compare/multiprocessing_compute.py +174 -149
- msprobe/core/compare/npy_compare.py +310 -295
- msprobe/core/compare/utils.py +464 -429
- msprobe/core/data_dump/data_collector.py +153 -144
- msprobe/core/data_dump/data_processor/base.py +337 -293
- msprobe/core/data_dump/data_processor/factory.py +76 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
- msprobe/core/data_dump/json_writer.py +117 -116
- msprobe/core/data_dump/scope.py +194 -178
- msprobe/core/grad_probe/constant.py +74 -70
- msprobe/core/grad_probe/grad_compare.py +170 -175
- msprobe/core/grad_probe/utils.py +77 -52
- msprobe/docs/01.installation.md +99 -0
- msprobe/docs/02.config_introduction.md +137 -0
- msprobe/docs/03.config_examples.md +237 -0
- msprobe/docs/04.acl_config_examples.md +78 -0
- msprobe/docs/05.data_dump_PyTorch.md +326 -0
- msprobe/docs/06.data_dump_MindSpore.md +285 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
- msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
- msprobe/docs/FAQ.md +189 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +2 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +58 -34
- msprobe/mindspore/common/const.py +108 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +97 -57
- msprobe/mindspore/compare/distributed_compare.py +62 -75
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +357 -117
- msprobe/mindspore/compare/ms_graph_compare.py +364 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +69 -74
- msprobe/mindspore/debugger/precision_debugger.py +150 -107
- msprobe/mindspore/dump/dump_tool_factory.py +50 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
- msprobe/mindspore/dump/jit_dump.py +96 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
- msprobe/mindspore/free_benchmark/common/config.py +27 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
- msprobe/mindspore/free_benchmark/common/utils.py +85 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
- msprobe/mindspore/grad_probe/global_context.py +100 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +297 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +23 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +30 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
- msprobe/pytorch/bench_functions/linear.py +27 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
- msprobe/pytorch/bench_functions/rms_norm.py +30 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
- msprobe/pytorch/bench_functions/swiglu.py +70 -55
- msprobe/pytorch/common/__init__.py +17 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +33 -32
- msprobe/pytorch/common/parse_json.py +54 -39
- msprobe/pytorch/common/utils.py +310 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +49 -33
- msprobe/pytorch/compare/pt_compare.py +82 -40
- msprobe/pytorch/debugger/debugger_config.py +108 -95
- msprobe/pytorch/debugger/precision_debugger.py +173 -125
- msprobe/pytorch/free_benchmark/__init__.py +23 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +65 -37
- msprobe/pytorch/free_benchmark/common/params.py +144 -129
- msprobe/pytorch/free_benchmark/common/utils.py +118 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
- msprobe/pytorch/free_benchmark/main.py +120 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
- msprobe/pytorch/function_factory.py +91 -75
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +166 -161
- msprobe/pytorch/hook_module/hook_module.py +118 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +28 -29
- msprobe/pytorch/hook_module/wrap_aten.py +111 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
- msprobe/pytorch/hook_module/wrap_functional.py +104 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
- msprobe/pytorch/hook_module/wrap_torch.py +84 -86
- msprobe/pytorch/hook_module/wrap_vf.py +60 -62
- msprobe/pytorch/module_processer.py +153 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +235 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
- msprobe/pytorch/online_dispatch/utils.py +127 -146
- msprobe/pytorch/parse.py +19 -4
- msprobe/pytorch/parse_tool/cli.py +31 -32
- msprobe/pytorch/parse_tool/lib/compare.py +259 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
- msprobe/pytorch/parse_tool/lib/utils.py +320 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +317 -187
- msprobe/pytorch/service.py +311 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
|
@@ -1,34 +1,49 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
return
|
|
34
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
import mindspore as ms
|
|
19
|
+
from mindspore import Tensor, ops
|
|
20
|
+
|
|
21
|
+
from msprobe.mindspore.common.const import Const
|
|
22
|
+
from msprobe.mindspore.common.log import logger
|
|
23
|
+
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
24
|
+
from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ImprovePrecisionPerturbation(BasePerturbation):
    """Perturbation that upcasts low-precision floating tensors to float32 and re-runs the API."""

    def improve_tensor_precision(self, target_tensor):
        """Return ``target_tensor`` with every floating tensor below float32 cast to ``ms.float32``.

        Floating tensors whose dtype is neither float32 nor float64 are cast, and
        ``self.is_fuzzed`` is set to True. Dicts, tuples and lists are walked
        recursively and rebuilt with the same container type. Any other value is
        returned unchanged.
        """
        needs_upcast = (
            isinstance(target_tensor, Tensor)
            and ops.is_floating_point(target_tensor)
            and target_tensor.dtype not in [ms.float64, ms.float32]
        )
        if needs_upcast:
            self.is_fuzzed = True
            return target_tensor.to(ms.float32)
        if isinstance(target_tensor, dict):
            return {key: self.improve_tensor_precision(item) for key, item in target_tensor.items()}
        if isinstance(target_tensor, (tuple, list)):
            # NOTE(review): type(x)(iterable) cannot rebuild a namedtuple with more
            # than one field; confirm namedtuples never reach this path.
            container_type = type(target_tensor)
            converted_items = [self.improve_tensor_precision(item) for item in target_tensor]
            return container_type(converted_items)
        return target_tensor

    def handle(self, params: HandlerParams) -> Any:
        """Call ``params.original_func`` with upcast args/kwargs.

        Returns False (after logging a warning) when no tensor could be upcast.
        For communication APIs, the upcast positional args are also stored on
        ``params.fuzzed_value``.
        """
        upcast_args = self.improve_tensor_precision(params.args)
        upcast_kwargs = self.improve_tensor_precision(params.kwargs)
        if self.api_name in Const.COMMUNICATION_API_LIST:
            params.fuzzed_value = upcast_args
        if not self.is_fuzzed:
            logger.warning(f"{self.api_name} can not improve precision.")
            return False
        return params.original_func(*upcast_args, **upcast_kwargs)
|
|
@@ -1,12 +1,27 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
19
|
+
from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class NoChangePerturbation(BasePerturbation):
    """Perturbation that leaves the selected input unmodified."""

    def handle(self, params: HandlerParams) -> Any:
        """Record ``params.args[params.index]`` as-is as the fuzzed value and return the fuzzed result."""
        untouched_input = params.args[params.index]
        params.fuzzed_value = untouched_input
        self.is_fuzzed = True
        return self.get_fuzzed_result(params)
|
|
@@ -1,27 +1,44 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
17
|
+
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
18
|
+
from msprobe.mindspore.free_benchmark.perturbation.add_noise import AddNoisePerturbation
|
|
19
|
+
from msprobe.mindspore.free_benchmark.perturbation.bit_noise import BitNoisePerturbation
|
|
20
|
+
from msprobe.mindspore.free_benchmark.perturbation.exchange_value import ExchangeValuePerturbation
|
|
21
|
+
from msprobe.mindspore.free_benchmark.perturbation.improve_precision import ImprovePrecisionPerturbation
|
|
22
|
+
from msprobe.mindspore.free_benchmark.perturbation.no_change import NoChangePerturbation
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class PerturbationFactory:
    """
    Perturbation factory: maps the configured perturbation type to its class.
    """
    perturbations = {
        FreeBenchmarkConst.IMPROVE_PRECISION: ImprovePrecisionPerturbation,
        FreeBenchmarkConst.ADD_NOISE: AddNoisePerturbation,
        FreeBenchmarkConst.BIT_NOISE: BitNoisePerturbation,
        FreeBenchmarkConst.NO_CHANGE: NoChangePerturbation,
        FreeBenchmarkConst.EXCHANGE_VALUE: ExchangeValuePerturbation
    }

    @staticmethod
    def create(api_name: str):
        """Instantiate the perturbation selected by ``Config.pert_type`` for ``api_name``.

        Args:
            api_name: name of the API being perturbed; passed to the perturbation constructor.

        Returns:
            A new instance of the perturbation class registered for ``Config.pert_type``.

        Raises:
            ValueError: if ``Config.pert_type`` is not a registered perturbation type.
                ValueError subclasses Exception, so existing ``except Exception``
                handlers still catch it.
        """
        perturbation = PerturbationFactory.perturbations.get(Config.pert_type)
        if not perturbation:
            raise ValueError(f'{Config.pert_type} is an invalid perturbation type')
        return perturbation(api_name)
|
|
@@ -1,33 +1,48 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.mindspore.common.const import Const
|
|
17
|
+
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
18
|
+
from msprobe.mindspore.free_benchmark.api_pynative_self_check import ApiPyNativeSelFCheck
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class SelfCheckToolFactory:
    """Resolve the free-benchmark self-check tool for a (level, execution mode) pair.

    In this table only API level in PyNative mode maps to an implementation;
    every other combination maps to None and is rejected by ``create``.
    """
    tools = {
        Const.CELL: {
            Const.GRAPH_KBYK_MODE: None,
            Const.GRAPH_GE_MODE: None,
            Const.PYNATIVE_MODE: None
        },
        Const.API: {
            Const.GRAPH_KBYK_MODE: None,
            Const.GRAPH_GE_MODE: None,
            Const.PYNATIVE_MODE: ApiPyNativeSelFCheck
        },
        Const.KERNEL: {
            Const.GRAPH_KBYK_MODE: None,
            Const.GRAPH_GE_MODE: None,
            Const.PYNATIVE_MODE: None
        }
    }

    @staticmethod
    def create(config: DebuggerConfig):
        """Build the self-check tool for ``config.level`` and ``config.execution_mode``.

        Raises:
            Exception: if the level is unknown, or no tool exists for the execution mode.
        """
        modes_for_level = SelfCheckToolFactory.tools.get(config.level)
        if not modes_for_level:
            raise Exception(f"{config.level} is not supported.")
        tool_cls = modes_for_level.get(config.execution_mode)
        if not tool_cls:
            raise Exception(f"Task free_benchmark is not supported in this mode: {config.execution_mode}.")
        return tool_cls(config)
|
|
@@ -1,91 +1,100 @@
|
|
|
1
|
-
import os
|
|
2
|
-
import threading
|
|
3
|
-
from typing import Dict, Union
|
|
4
|
-
|
|
5
|
-
from msprobe.core.grad_probe.utils import check_str
|
|
6
|
-
from msprobe.core.grad_probe.constant import GradConst
|
|
7
|
-
from msprobe.
|
|
8
|
-
from msprobe.core.common.
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
GradConst.
|
|
18
|
-
GradConst.
|
|
19
|
-
GradConst.
|
|
20
|
-
GradConst.
|
|
21
|
-
GradConst.
|
|
22
|
-
GradConst.
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
cls.
|
|
29
|
-
cls.
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
level =
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
self._set_input_list(config_dict, GradConst.
|
|
42
|
-
self._set_input_list(config_dict, GradConst.
|
|
43
|
-
self._set_input_list(config_dict, GradConst.
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
output_path =
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
self._setting
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
dump_step_list
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
dump_rank_list
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
if dtype == int:
|
|
77
|
-
type_str = "integer"
|
|
78
|
-
elif dtype == float:
|
|
79
|
-
type_str = "float"
|
|
80
|
-
else:
|
|
81
|
-
type_str = "string"
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
1
|
+
import os
|
|
2
|
+
import threading
|
|
3
|
+
from typing import Dict, Union, Tuple
|
|
4
|
+
|
|
5
|
+
from msprobe.core.grad_probe.utils import check_str, check_bounds_element
|
|
6
|
+
from msprobe.core.grad_probe.constant import GradConst
|
|
7
|
+
from msprobe.mindspore.common.log import logger
|
|
8
|
+
from msprobe.core.common.file_utils import create_directory, check_path_before_create
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GlobalContext:
    """Singleton holding the grad-probe configuration and the current step counter.

    Settings live in the class-level ``_setting`` dict. Because ``__new__``
    always returns the same instance, every ``GlobalContext()`` shares them.
    """

    _instance = None
    _instance_lock = threading.Lock()
    _setting = {
        GradConst.LEVEL: None,
        GradConst.PARAM_LIST: None,
        GradConst.STEP: None,
        GradConst.RANK: None,
        GradConst.CURRENT_STEP: 0,
        GradConst.BOUNDS: [-1, 0, 1],
        GradConst.OUTPUT_PATH: None
    }

    def __new__(cls, *args, **kwargs):
        # Double-checked locking:
        # - the unlocked test keeps the common path cheap;
        # - the second test under the lock stops two racing threads from each
        #   creating an instance;
        # - ``with`` releases the lock even if allocation raises (a bare
        #   acquire()/release() pair would not).
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    cls._instance = object.__new__(cls)
        return cls._instance

    def init_context(self, config_dict: Dict):
        """Validate ``config_dict`` and load it into the shared settings.

        Stores level, param list, bounds, step and rank lists, and the output
        path. Creates the output directory if it does not already exist.

        Args:
            config_dict: parsed grad-probe config keyed by ``GradConst`` names.

        Raises:
            ValueError: if the level is not in ``GradConst.SUPPORTED_LEVEL``, or
                ``output_path`` is rejected by ``check_path_before_create``.
                ``check_str`` may raise as well (behavior defined elsewhere).
        """
        level = config_dict.get(GradConst.LEVEL)
        check_str(level, variable_name="level in yaml")
        if level not in GradConst.SUPPORTED_LEVEL:
            raise ValueError("Invalid level set in config yaml file, level option: L0, L1, L2")
        self._setting[GradConst.LEVEL] = level

        self._set_input_list(config_dict, GradConst.PARAM_LIST, str)
        self._set_input_list(config_dict, GradConst.BOUNDS, (float, int), element_check=check_bounds_element)
        self._set_input_list(config_dict, GradConst.STEP, int)
        self._set_input_list(config_dict, GradConst.RANK, int)

        output_path = config_dict.get(GradConst.OUTPUT_PATH)
        check_str(output_path, variable_name="output_path in yaml")
        try:
            check_path_before_create(output_path)
        except RuntimeError as err:
            raise ValueError(f"Invalid output_path: {output_path}. The error message is {err}.") from err
        self._setting[GradConst.OUTPUT_PATH] = output_path
        if os.path.isdir(output_path):
            logger.warning("The output_path exists, the data will be covered.")
        else:
            create_directory(output_path)

    def get_context(self, key: str):
        """Return the setting for ``key`` (None if absent), warning on unknown keys."""
        if key not in self._setting:
            logger.warning(f"Unrecognized {key}.")
        return self._setting.get(key)

    def update_step(self):
        """Advance the current step counter by one."""
        self._setting[GradConst.CURRENT_STEP] += 1

    def step_need_dump(self, step):
        """True when no step list is configured, or ``step`` is in it."""
        dump_step_list = self.get_context(GradConst.STEP)
        return (not dump_step_list) or (step in dump_step_list)

    def rank_need_dump(self, rank):
        """True when no rank list is configured, or ``rank`` is in it."""
        dump_rank_list = self.get_context(GradConst.RANK)
        return (not dump_rank_list) or (rank in dump_rank_list)

    def _get_type_str(self, dtype: Union[int, str, float, Tuple[int, str, float]]):
        """Human-readable name(s) of ``dtype``; tuples are joined with '/'."""
        if isinstance(dtype, tuple):
            return "/".join([self._get_type_str(element) for element in dtype])
        if dtype == int:
            type_str = "integer"
        elif dtype == float:
            type_str = "float"
        else:
            type_str = "string"
        return type_str

    def _set_input_list(self, config_dict: Dict, name: str,
                        dtype: Union[int, str, float, Tuple[int, str, float]], element_check=None):
        """Store ``config_dict[name]`` if it is a non-empty list of valid items.

        Each item must be an instance of ``dtype`` and, when given, satisfy
        ``element_check``. ``bool`` is rejected even though it subclasses
        ``int``. On any violation a warning is logged and the existing
        (default) setting is kept.
        """
        value = config_dict.get(name)
        if not value or not isinstance(value, list):
            logger.warning(f"{name} is None or not a list with valid items, use default value.")
            return
        type_str = self._get_type_str(dtype)
        for val in value:
            # isinstance(True, int) is True; without the bool check a YAML `true`
            # would silently be accepted as step/rank 1 or a bound of 1.
            if isinstance(val, bool) or not isinstance(val, dtype):
                logger.warning(f"Invalid {name} which must be None or list of {type_str}")
                return
            if element_check and not element_check(val):
                logger.warning(f"Given {name} violates some rules.")
                return
        self._setting[name] = value


grad_context = GlobalContext()
|