mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +131 -237
- msprobe/__init__.py +16 -1
- msprobe/{config/config.json → config.json} +47 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +58 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +402 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +523 -283
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +86 -69
- msprobe/core/common/utils.py +371 -616
- msprobe/core/common_config.py +78 -71
- msprobe/core/compare/acc_compare.py +472 -298
- msprobe/core/compare/check.py +180 -95
- msprobe/core/compare/compare_cli.py +69 -49
- msprobe/core/compare/highlight.py +259 -222
- msprobe/core/compare/multiprocessing_compute.py +174 -149
- msprobe/core/compare/npy_compare.py +310 -295
- msprobe/core/compare/utils.py +464 -429
- msprobe/core/data_dump/data_collector.py +153 -144
- msprobe/core/data_dump/data_processor/base.py +337 -293
- msprobe/core/data_dump/data_processor/factory.py +76 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
- msprobe/core/data_dump/json_writer.py +117 -116
- msprobe/core/data_dump/scope.py +194 -178
- msprobe/core/grad_probe/constant.py +74 -70
- msprobe/core/grad_probe/grad_compare.py +170 -175
- msprobe/core/grad_probe/utils.py +77 -52
- msprobe/docs/01.installation.md +99 -0
- msprobe/docs/02.config_introduction.md +137 -0
- msprobe/docs/03.config_examples.md +237 -0
- msprobe/docs/04.acl_config_examples.md +78 -0
- msprobe/docs/05.data_dump_PyTorch.md +326 -0
- msprobe/docs/06.data_dump_MindSpore.md +285 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
- msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
- msprobe/docs/FAQ.md +189 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +2 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +58 -34
- msprobe/mindspore/common/const.py +108 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +97 -57
- msprobe/mindspore/compare/distributed_compare.py +62 -75
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +357 -117
- msprobe/mindspore/compare/ms_graph_compare.py +364 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +69 -74
- msprobe/mindspore/debugger/precision_debugger.py +150 -107
- msprobe/mindspore/dump/dump_tool_factory.py +50 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
- msprobe/mindspore/dump/jit_dump.py +96 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
- msprobe/mindspore/free_benchmark/common/config.py +27 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
- msprobe/mindspore/free_benchmark/common/utils.py +85 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
- msprobe/mindspore/grad_probe/global_context.py +100 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +297 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +23 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +30 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
- msprobe/pytorch/bench_functions/linear.py +27 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
- msprobe/pytorch/bench_functions/rms_norm.py +30 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
- msprobe/pytorch/bench_functions/swiglu.py +70 -55
- msprobe/pytorch/common/__init__.py +17 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +33 -32
- msprobe/pytorch/common/parse_json.py +54 -39
- msprobe/pytorch/common/utils.py +310 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +49 -33
- msprobe/pytorch/compare/pt_compare.py +82 -40
- msprobe/pytorch/debugger/debugger_config.py +108 -95
- msprobe/pytorch/debugger/precision_debugger.py +173 -125
- msprobe/pytorch/free_benchmark/__init__.py +23 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +65 -37
- msprobe/pytorch/free_benchmark/common/params.py +144 -129
- msprobe/pytorch/free_benchmark/common/utils.py +118 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
- msprobe/pytorch/free_benchmark/main.py +120 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
- msprobe/pytorch/function_factory.py +91 -75
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +166 -161
- msprobe/pytorch/hook_module/hook_module.py +118 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +28 -29
- msprobe/pytorch/hook_module/wrap_aten.py +111 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
- msprobe/pytorch/hook_module/wrap_functional.py +104 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
- msprobe/pytorch/hook_module/wrap_torch.py +84 -86
- msprobe/pytorch/hook_module/wrap_vf.py +60 -62
- msprobe/pytorch/module_processer.py +153 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +235 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
- msprobe/pytorch/online_dispatch/utils.py +127 -146
- msprobe/pytorch/parse.py +19 -4
- msprobe/pytorch/parse_tool/cli.py +31 -32
- msprobe/pytorch/parse_tool/lib/compare.py +259 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
- msprobe/pytorch/parse_tool/lib/utils.py +320 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +317 -187
- msprobe/pytorch/service.py +311 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
|
@@ -1,45 +1,60 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from abc import abstractmethod
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
21
|
+
from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class NpuBaseLayer(BaseLayer):
    """Shared base for perturbation layers that run the API on the NPU.

    Subclasses implement ``handle``. This base provides:
    - re-invoking the original API with one argument replaced by its perturbed value;
    - the eligibility check that decides which tensor receives the perturbation.
    """

    def __init__(self, api_name: str) -> None:
        super().__init__(api_name)
        # The perturbed element that will be substituted into the call.
        self.perturbed_value = None
        # Set once the current operator's input has been perturbed.
        self.is_added = False

    @staticmethod
    def perturbed_result(params: DataParams) -> Any:
        """Re-run ``params.origin_func`` with the perturbed argument swapped in.

        The positional argument at ``params.valid_input_index`` is replaced by
        ``params.perturbed_value``. All other positional and keyword arguments
        are passed through. The result is stored on ``params.perturbed_result``
        and also returned.
        """
        idx = params.valid_input_index
        leading = params.args[:idx]
        trailing = params.args[idx + 1:]
        # Force operators with an ``inplace`` keyword to run out-of-place.
        # Note: this writes into ``params.kwargs`` itself, not a copy.
        if "inplace" in params.kwargs:
            params.kwargs["inplace"] = False
        params.perturbed_result = params.origin_func(
            *leading, params.perturbed_value, *trailing, **params.kwargs
        )
        return params.perturbed_result

    @abstractmethod
    def handle(self, params: DataParams) -> Any:
        """Apply this layer's perturbation; implemented by subclasses."""

    def pre_check(self, tensor_obj):
        """Return True if ``tensor_obj`` should receive the perturbation.

        A tensor qualifies only when:
        - no input has been perturbed yet (only the first eligible tensor is used);
        - it has a floating-point dtype;
        - the subclass hook ``_check_details`` accepts it.

        The original note describes the intended criterion as "float dtype and
        max value greater than the minimum of the corresponding precision". The
        value part is left to ``_check_details`` overrides.
        """
        return bool(
            not self.is_added
            and torch.is_floating_point(tensor_obj)
            and self._check_details(tensor_obj)
        )

    def _check_details(self, tensor_obj):
        """Subclass hook for extra eligibility checks; accepts everything by default."""
        return True
|
|
@@ -1,19 +1,34 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
from msprobe.pytorch.free_benchmark import logger
|
|
18
|
+
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
19
|
+
from msprobe.pytorch.free_benchmark.common.utils import Tools
|
|
20
|
+
from msprobe.pytorch.free_benchmark.common.enums import DeviceType
|
|
21
|
+
from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class CpuLayer(BaseLayer):
    """Perturbation layer that re-executes the original API on the CPU."""

    def handle(self, params: DataParams):
        """Run ``params.origin_func`` on CPU copies of the inputs.

        Positional and keyword arguments are converted with
        ``Tools.convert_device_and_dtype(..., DeviceType.CPU, change_dtype=True)``.
        The result is stored on ``params.perturbed_result`` and returned.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is to_cpu of {self.api_name}."
        )
        cpu_args = Tools.convert_device_and_dtype(
            params.args, DeviceType.CPU, change_dtype=True
        )
        cpu_kwargs = Tools.convert_device_and_dtype(
            params.kwargs, DeviceType.CPU, change_dtype=True
        )
        params.perturbed_result = params.origin_func(*cpu_args, **cpu_kwargs)
        return params.perturbed_result
|
|
@@ -1,217 +1,256 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
if
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import math
|
|
17
|
+
from abc import ABC, abstractmethod
|
|
18
|
+
from typing import Any, Optional, Tuple
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import torch
|
|
22
|
+
from msprobe.core.common.const import Const
|
|
23
|
+
from msprobe.pytorch.free_benchmark import logger
|
|
24
|
+
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
25
|
+
from msprobe.pytorch.free_benchmark.common.enums import (
|
|
26
|
+
FuzzThreshold,
|
|
27
|
+
NormType,
|
|
28
|
+
PerturbationMode,
|
|
29
|
+
)
|
|
30
|
+
from msprobe.pytorch.free_benchmark.common.params import (
|
|
31
|
+
DataParams,
|
|
32
|
+
HandlerParams,
|
|
33
|
+
make_unequal_row,
|
|
34
|
+
)
|
|
35
|
+
from msprobe.pytorch.free_benchmark.common.utils import Tools, TorchC
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class FuzzHandler(ABC):
    """Base class for free-benchmark handlers comparing original vs. perturbed outputs.

    Subclasses implement ``get_threshold`` (the allowed ratio bound for a dtype)
    and ``handle``. Outputs judged inconsistent are recorded as rows built by
    ``make_unequal_row`` and exposed through ``get_unequal_rows``.
    """

    def __init__(self, params: HandlerParams) -> None:
        self.params = params
        # Rows describing outputs found to be inconsistent.
        self.unequal_rows = []

    @staticmethod
    def pre_process(origin_ouput, perturbed_output):
        """Align a pair of outputs for ratio computation and pick a tolerance.

        Tuple outputs that expose both ``values`` and ``indices`` are reduced
        to their ``values``. The absolute tolerance is looked up in
        ThresholdConfig.ABS_TOL_VALUE_DICT by the perturbed output's dtype,
        falling back to FuzzThreshold.F32_THD.

        :param origin_ouput: original (unperturbed) output
        :param perturbed_output: output produced under perturbation
        :return: (original output cast to the perturbed output's dtype and
            device, perturbed output, absolute tolerance)
        """
        if (
            isinstance(origin_ouput, tuple)
            and hasattr(origin_ouput, "values")
            and hasattr(origin_ouput, "indices")
        ):
            origin_ouput = origin_ouput.values
            perturbed_output = perturbed_output.values
        if hasattr(perturbed_output, "dtype"):
            abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
                perturbed_output.dtype, FuzzThreshold.F32_THD
            )
        else:
            abs_tol = FuzzThreshold.F32_THD
        return (
            origin_ouput.to(perturbed_output.dtype).to(perturbed_output.device),
            perturbed_output,
            abs_tol,
        )

    @staticmethod
    def tensor_split_for_error_calculate(origin_output, perturbed_output):
        """Split the original and perturbed outputs into chunks for error calculation.

        Both tensors are flattened and split into the same number of chunks.
        Empty and 0-dim tensors are returned unsplit.

        :param origin_output: original output
        :param perturbed_output: perturbed output
        :return origin_output_chunks: list of chunks of the original output
        :return perturbed_output_chunks: list of chunks of the perturbed output
        """
        single_output_mem = (
            origin_output.element_size() * origin_output.nelement() / Const.ONE_MB
        )
        if single_output_mem == 0 or origin_output.ndim == 0:
            return [origin_output], [perturbed_output]
        # Chunk count grows with tensor size:
        #   chunks_exp = log2(M) - 4, chunks = 2 ** chunks_exp
        # where M is the tensor size in units of Const.ONE_MB. The result is
        # clamped to [1, TENSOR_SPLIT_MAX_CHUNK].
        chunks_exp = int(math.log(single_output_mem, 2)) - 4
        chunks = 2**chunks_exp
        chunks = max(chunks, 1)
        chunks = min(chunks, ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK)
        origin_output_chunks = TorchC.tensor_split(
            TorchC.reshape(origin_output, (-1,)), chunks
        )
        perturbed_output_chunks = TorchC.tensor_split(
            TorchC.reshape(perturbed_output, (-1,)), chunks
        )
        return origin_output_chunks, perturbed_output_chunks

    @staticmethod
    def convert_overflow_ratio_to_consistent(ratio):
        """Map a NaN or infinite ratio to ThresholdConfig.COMP_CONSISTENT."""
        if math.isnan(ratio) or math.isinf(ratio):
            return ThresholdConfig.COMP_CONSISTENT
        return ratio

    @abstractmethod
    def get_threshold(self, dtype):
        """Return the ratio threshold used for outputs of ``dtype``."""
        pass

    @abstractmethod
    def handle(self, data_params: DataParams) -> Any:
        """Process one compared API call; implemented by subclasses."""
        pass

    def get_ratio_from_specific_norm(
        self, origin_output, perturbed_output, norm_type, abs_tol
    ):
        """Compute the error ratio for ``norm_type``.

        Only NormType.ENDLESS_NORM is computed. Any other norm type yields
        ThresholdConfig.COMP_CONSISTENT.
        """
        if norm_type == NormType.ENDLESS_NORM:
            return self.calculate_error(origin_output, perturbed_output, abs_tol)
        return ThresholdConfig.COMP_CONSISTENT

    def calculate_error(self, origin_output, perturbed_output, abs_tol):
        """Return the worst element-wise ratio between the two outputs.

        For each chunk, two maxima are taken:
        - origin/perturbed, computed only where |perturbed| > abs_tol (else 1);
        - perturbed/origin, computed only where |origin| > abs_tol (else 1).
        In both, operands are clamped to at least ``abs_tol``. NaN/inf maxima
        count as consistent. An output with no elements is reported as
        ThresholdConfig.COMP_CONSISTENT.
        """
        origin_output_chunks, perturbed_output_chunks = (
            self.tensor_split_for_error_calculate(origin_output, perturbed_output)
        )
        norm1 = -np.inf
        norm2 = -np.inf
        norm3 = np.inf
        compared_any = False
        for i, chunk_origin in enumerate(origin_output_chunks):
            # tensor_split places any empty chunks last, so stop at the first one.
            if chunk_origin.nelement() == 0:
                break
            compared_any = True
            chunk_perturbed = perturbed_output_chunks[i]
            ratio_tensor1 = TorchC.where(
                TorchC.abs(chunk_perturbed) > abs_tol,
                TorchC.div(
                    TorchC.clamp(chunk_origin, min=abs_tol),
                    TorchC.clamp(chunk_perturbed, min=abs_tol),
                ),
                1,
            )
            ratio_tensor2 = TorchC.where(
                TorchC.abs(chunk_origin) > abs_tol,
                TorchC.div(
                    TorchC.clamp(chunk_perturbed, min=abs_tol),
                    TorchC.clamp(chunk_origin, min=abs_tol),
                ),
                1,
            )
            # Stack both maxima so a single .tolist() fetches them together.
            norm_values = TorchC.stack(
                [TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)]
            )
            max_ratio1, max_ratio2 = norm_values.tolist()
            norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1))
            norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
            norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(max_ratio1))

        if not compared_any:
            # Empty output: nothing to compare. Without this guard norm1/norm2
            # stay at -inf, which npu_compare would report as inconsistent.
            return ThresholdConfig.COMP_CONSISTENT

        # NOTE(review): both division operands are clamped to >= abs_tol, so
        # the ratios look non-negative and this branch may be unreachable -
        # confirm the intended symbol-flipping detection.
        if norm3 < 0:
            ratio = ThresholdConfig.SYMBOL_FLIPPING
        else:
            ratio = max(norm1, norm2)
        return ratio

    def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float:
        """Pre-process both outputs and compute their error ratio.

        Returns ThresholdConfig.COMP_NAN (after logging a warning) when
        pre-processing fails. The tolerance used is
        ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND in the backward stage;
        otherwise it is the square root of the dtype tolerance.
        """
        try:
            origin_output, perturbed_output, abs_tol = self.pre_process(
                origin_output, perturbed_output
            )
        except Exception as e:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"when computing ratio,"
                f" y1 or y2 dtype is not supported {e}"
            )
            return ThresholdConfig.COMP_NAN
        if self.params.fuzz_stage == Const.BACKWARD:
            abs_tol = ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND
        else:
            abs_tol = abs_tol**0.5
        return self.get_ratio_from_specific_norm(
            origin_output, perturbed_output, norm_type, abs_tol
        )

    def npu_compare(
        self, origin_output, perturbed_output
    ) -> Tuple[bool, Optional[float]]:
        """Compare a single original/perturbed output pair.

        Handling by type of ``perturbed_output``:
        - int: exact equality; ratio is None.
        - float: ``math.isclose``, with both sides shifted by F32_THD when the
          perturbed value is 0; ratio is origin / perturbed.
        - tensor: consistent when ``threshold >= ratio >= 1 / threshold`` and
          the ratio is not SYMBOL_FLIPPING.
        - anything else: logged as unsupported and skipped (True, None).

        :return: (is_consistent, ratio or None)
        """
        if isinstance(perturbed_output, int):
            return origin_output == perturbed_output, None
        elif isinstance(perturbed_output, float):
            if perturbed_output == 0:
                origin_output += FuzzThreshold.F32_THD
                perturbed_output += FuzzThreshold.F32_THD
            return (
                math.isclose(origin_output, perturbed_output),
                origin_output / perturbed_output,
            )
        elif not isinstance(perturbed_output, torch.Tensor):
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name} "
                f"The compare for output type {type(perturbed_output)} is not supported"
            )
            # Skip unsupported types instead of falling through to the tensor
            # path. That path is not meant for non-tensors and may raise; inside
            # cmp_output_npu's loop, an exception skips all remaining outputs.
            return True, None

        threshold = self.get_threshold(Tools.get_first_tensor_dtype(origin_output))
        ratio = self.ratio_calculate(
            origin_output, perturbed_output, norm_type=NormType.ENDLESS_NORM
        )
        if ratio == ThresholdConfig.SYMBOL_FLIPPING:
            is_consistent = False
        else:
            is_consistent = threshold >= ratio >= 1 / threshold
        return is_consistent, ratio

    def cmp_output_npu(self, data_params: DataParams):
        """Compare ``original_result`` with ``perturbed_result`` on the NPU.

        Handles a single tensor or a list/tuple compared element by element.
        The method updates ``data_params.is_consistent`` and, when
        ``grad_unequal_flag`` is set, appends an unequal row for each
        inconsistent element. Exceptions are logged and end the comparison.

        :return: (overall consistency, maximum numeric ratio seen, 0 if none)
        """
        npu_consistent = True
        max_fuzz_ratio = 0
        try:
            if isinstance(data_params.original_result, torch.Tensor):
                is_consistent, ratio = self.npu_compare(
                    data_params.original_result, data_params.perturbed_result
                )
                npu_consistent = is_consistent
                max_fuzz_ratio = (
                    max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
                )
                data_params.is_consistent = is_consistent and data_params.is_consistent
                if not is_consistent and data_params.grad_unequal_flag:
                    self.unequal_rows.append(
                        make_unequal_row(data_params, self.params, ratio=ratio)
                    )

            elif isinstance(data_params.original_result, (list, tuple)):
                for index_, origin_item in enumerate(data_params.original_result):
                    is_consistent, ratio = self.npu_compare(
                        origin_item, data_params.perturbed_result[index_]
                    )
                    npu_consistent = npu_consistent and is_consistent
                    max_fuzz_ratio = (
                        max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
                    )
                    data_params.is_consistent = (
                        is_consistent and data_params.is_consistent
                    )
                    if not is_consistent and data_params.grad_unequal_flag:
                        self.unequal_rows.append(
                            make_unequal_row(
                                data_params, self.params, ratio=ratio, index=index_
                            )
                        )
        except Exception as e:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"when campare the result exception raise {e}"
            )
        return npu_consistent, max_fuzz_ratio

    def get_unequal_rows(self):
        """Return the unequal rows accumulated so far."""
        return self.unequal_rows

    def _get_default_threshold(self, dtype):
        """Return the default ratio threshold for ``dtype``.

        PerturbationMode.NO_CHANGE uses COMP_CONSISTENT. Other modes look up
        ThresholdConfig.DTYPE_PER_THD, falling back to the float32 entry.
        """
        if self.params.pert_mode == PerturbationMode.NO_CHANGE:
            threshold = ThresholdConfig.COMP_CONSISTENT
        else:
            threshold = ThresholdConfig.DTYPE_PER_THD.get(
                dtype, ThresholdConfig.DTYPE_PER_THD.get(torch.float32)
            )
        return threshold
|