mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright 2024 Huawei Technologies Co., Ltd
|
|
3
|
+
*
|
|
4
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
* you may not use this file except in compliance with the License.
|
|
6
|
+
* You may obtain a copy of the License at
|
|
7
|
+
*
|
|
8
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
*
|
|
10
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
* See the License for the specific language governing permissions and
|
|
14
|
+
* limitations under the License.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
#include "hook_dynamic_loader.h"
|
|
18
|
+
#include <sys/stat.h>
|
|
19
|
+
#include <cstdlib>
|
|
20
|
+
#include <cstring>
|
|
21
|
+
#include "utils/log_adapter.h"
|
|
22
|
+
|
|
23
|
+
namespace {
|
|
24
|
+
|
|
25
|
+
// Utility function to check if a file path is valid
|
|
26
|
+
bool IsValidPath(const std::string &path) {
|
|
27
|
+
struct stat fileStat;
|
|
28
|
+
if (stat(path.c_str(), &fileStat) != 0) {
|
|
29
|
+
MS_LOG(ERROR) << "File does not exist or cannot be accessed: " << path;
|
|
30
|
+
return false;
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
if (S_ISLNK(fileStat.st_mode)) {
|
|
34
|
+
MS_LOG(ERROR) << "File is a symbolic link, which is not allowed: " << path;
|
|
35
|
+
return false;
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
if (!S_ISREG(fileStat.st_mode)) {
|
|
39
|
+
MS_LOG(ERROR) << "File is not a regular file: " << path;
|
|
40
|
+
return false;
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
if (path.substr(path.find_last_of(".")) != ".so") {
|
|
44
|
+
MS_LOG(ERROR) << "File is not a .so file: " << path;
|
|
45
|
+
return false;
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
return true;
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
} // namespace
|
|
52
|
+
|
|
53
|
+
// Returns the shared HookDynamicLoader object.
// The object is a function-local static, so it is created on the first call
// and the same reference is handed out on every later call.
HookDynamicLoader &HookDynamicLoader::GetInstance() {
  static HookDynamicLoader loader;
  return loader;
}
|
|
57
|
+
|
|
58
|
+
bool HookDynamicLoader::loadFunction(void *handle, const std::string &functionName) {
|
|
59
|
+
void *func = dlsym(handle, functionName.c_str());
|
|
60
|
+
if (!func) {
|
|
61
|
+
MS_LOG(WARNING) << "Could not load function: " << functionName << ", error: " << dlerror();
|
|
62
|
+
return false;
|
|
63
|
+
}
|
|
64
|
+
funcMap_[functionName] = func;
|
|
65
|
+
return true;
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
bool HookDynamicLoader::validateLibraryPath(const std::string &libPath) {
|
|
69
|
+
char *realPath = realpath(libPath.c_str(), nullptr);
|
|
70
|
+
if (!realPath) {
|
|
71
|
+
MS_LOG(WARNING) << "Failed to resolve realpath for the library: " << libPath;
|
|
72
|
+
return false;
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
bool isValid = IsValidPath(realPath);
|
|
76
|
+
free(realPath); // Free memory allocated by realpath
|
|
77
|
+
return isValid;
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
bool HookDynamicLoader::LoadLibrary() {
|
|
81
|
+
const char *libPath = std::getenv("HOOK_TOOL_PATH");
|
|
82
|
+
if (!libPath) {
|
|
83
|
+
MS_LOG(WARNING) << "HOOK_TOOL_PATH is not set!";
|
|
84
|
+
return false;
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
std::string resolvedLibPath(libPath);
|
|
88
|
+
if (!validateLibraryPath(resolvedLibPath)) {
|
|
89
|
+
MS_LOG(WARNING) << "Library path validation failed.";
|
|
90
|
+
return false;
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
std::lock_guard<std::mutex> lock(mutex_);
|
|
94
|
+
if (handle_) {
|
|
95
|
+
MS_LOG(WARNING) << "Hook library already loaded!";
|
|
96
|
+
return false;
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
handle_ = dlopen(resolvedLibPath.c_str(), RTLD_LAZY | RTLD_LOCAL);
|
|
100
|
+
if (!handle_) {
|
|
101
|
+
MS_LOG(WARNING) << "Failed to load Hook library: " << dlerror();
|
|
102
|
+
return false;
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
for (const auto &functionName : functionList_) {
|
|
106
|
+
if (!loadFunction(handle_, functionName)) {
|
|
107
|
+
MS_LOG(WARNING) << "Failed to load function: " << functionName;
|
|
108
|
+
dlclose(handle_);
|
|
109
|
+
handle_ = nullptr;
|
|
110
|
+
return false;
|
|
111
|
+
}
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
MS_LOG(INFO) << "Hook library loaded successfully.";
|
|
115
|
+
return true;
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
bool HookDynamicLoader::UnloadLibrary() {
|
|
119
|
+
std::lock_guard<std::mutex> lock(mutex_);
|
|
120
|
+
if (!handle_) {
|
|
121
|
+
MS_LOG(WARNING) << "Hook library hasn't been loaded.";
|
|
122
|
+
return false;
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
dlclose(handle_);
|
|
126
|
+
handle_ = nullptr;
|
|
127
|
+
funcMap_.clear();
|
|
128
|
+
MS_LOG(INFO) << "Library unloaded successfully.";
|
|
129
|
+
return true;
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
void *HookDynamicLoader::GetHooker(const std::string &funcName) {
|
|
133
|
+
std::lock_guard<std::mutex> lock(mutex_);
|
|
134
|
+
auto iter = funcMap_.find(funcName);
|
|
135
|
+
if (iter == funcMap_.end()) {
|
|
136
|
+
MS_LOG(WARNING) << "Function not found: " << funcName;
|
|
137
|
+
return nullptr;
|
|
138
|
+
}
|
|
139
|
+
return iter->second;
|
|
140
|
+
}
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright 2024 Huawei Technologies Co., Ltd
|
|
3
|
+
*
|
|
4
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
* you may not use this file except in compliance with the License.
|
|
6
|
+
* You may obtain a copy of the License at
|
|
7
|
+
*
|
|
8
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
*
|
|
10
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
* See the License for the specific language governing permissions and
|
|
14
|
+
* limitations under the License.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
#ifndef HOOK_DYNAMIC_LOADER_H
|
|
18
|
+
#define HOOK_DYNAMIC_LOADER_H
|
|
19
|
+
|
|
20
|
+
#include <dlfcn.h>
|
|
21
|
+
#include <string>
|
|
22
|
+
#include <vector>
|
|
23
|
+
#include <map>
|
|
24
|
+
#include <mutex>
|
|
25
|
+
|
|
26
|
+
constexpr auto kHookBegin = "MS_DbgOnStepBegin";
|
|
27
|
+
constexpr auto kHookEnd = "MS_DbgOnStepEnd";
|
|
28
|
+
|
|
29
|
+
class HookDynamicLoader {
|
|
30
|
+
public:
|
|
31
|
+
static HookDynamicLoader &GetInstance();
|
|
32
|
+
|
|
33
|
+
HookDynamicLoader(const HookDynamicLoader &) = delete;
|
|
34
|
+
HookDynamicLoader &operator=(const HookDynamicLoader &) = delete;
|
|
35
|
+
|
|
36
|
+
bool LoadLibrary();
|
|
37
|
+
bool UnloadLibrary();
|
|
38
|
+
void *GetHooker(const std::string &funcName);
|
|
39
|
+
|
|
40
|
+
private:
|
|
41
|
+
// Helper functions
|
|
42
|
+
bool loadFunction(void *handle, const std::string &functionName);
|
|
43
|
+
bool validateLibraryPath(const std::string &libPath);
|
|
44
|
+
|
|
45
|
+
HookDynamicLoader() = default;
|
|
46
|
+
|
|
47
|
+
void *handle_ = nullptr;
|
|
48
|
+
std::vector<std::string> functionList_ = {kHookBegin, kHookEnd};
|
|
49
|
+
std::map<std::string, void *> funcMap_;
|
|
50
|
+
std::mutex mutex_;
|
|
51
|
+
};
|
|
52
|
+
|
|
53
|
+
#endif // HOOK_DYNAMIC_LOADER_H
|
|
@@ -1,21 +1,43 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import functools
|
|
3
17
|
import importlib
|
|
18
|
+
import os
|
|
19
|
+
import traceback
|
|
4
20
|
|
|
5
21
|
import mindspore as ms
|
|
6
|
-
from mindspore.communication import comm_func
|
|
7
22
|
|
|
8
|
-
from msprobe.core.common.file_utils import load_yaml, check_path_length
|
|
9
23
|
from msprobe.core.common.const import Const
|
|
24
|
+
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
25
|
+
from msprobe.core.common.file_utils import check_path_length, load_yaml
|
|
10
26
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
11
27
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
12
|
-
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
13
28
|
from msprobe.mindspore.common.log import logger
|
|
29
|
+
from msprobe.mindspore.common.utils import get_rank_if_initialized
|
|
14
30
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
15
|
-
from msprobe.mindspore.
|
|
31
|
+
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
32
|
+
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
33
|
+
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
34
|
+
from msprobe.mindspore.free_benchmark.common.utils import Tools
|
|
35
|
+
from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory
|
|
36
|
+
from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory
|
|
37
|
+
from msprobe.mindspore.runtime import Runtime
|
|
16
38
|
|
|
17
39
|
|
|
18
|
-
class
|
|
40
|
+
class ApiPyNativeSelfCheck:
|
|
19
41
|
def __init__(self, config: DebuggerConfig):
|
|
20
42
|
Config.is_enable = True
|
|
21
43
|
Config.handler_type = config.handler_type
|
|
@@ -24,29 +46,68 @@ class ApiPyNativeSelFCheck:
|
|
|
24
46
|
Config.dump_level = config.dump_level
|
|
25
47
|
Config.steps = config.step
|
|
26
48
|
Config.ranks = config.rank
|
|
27
|
-
Config.dump_path = os.path.join(config.dump_path,
|
|
49
|
+
Config.dump_path = os.path.join(config.dump_path, FreeBenchmarkConst.CHECK_RESULT_FILE)
|
|
28
50
|
check_path_length(Config.dump_path)
|
|
29
51
|
|
|
52
|
+
self.ori_func = {}
|
|
53
|
+
|
|
30
54
|
self.api_list = config.list
|
|
31
55
|
all_api = get_supported_ops()
|
|
32
56
|
if not self.api_list:
|
|
33
57
|
self.api_list = all_api
|
|
34
58
|
else:
|
|
35
59
|
self.api_list = set(self.api_list) & all_api
|
|
60
|
+
self.store_original_func()
|
|
36
61
|
|
|
37
62
|
def handle(self):
    """Register this checker's hook builder and switch the APIs to hooks.

    Passes ``self.build_hook`` to ``api_register.initialize_hook`` and then
    calls ``api_register.api_set_hook_func``.
    """
    hook_builder = self.build_hook
    api_register.initialize_hook(hook_builder)
    api_register.api_set_hook_func()
|
|
65
|
+
|
|
66
|
+
def build_hook(self, api_name_with_id):
    """Create the forward and backward hook callables for one API name.

    Args:
        api_name_with_id: Hook name; its last character is dropped and the
            remainder is split on ``Const.SEP`` to derive the API name.

    Returns:
        A ``(forward_hook, backward_hook)`` tuple. The forward hook runs
        ``check_self`` for APIs contained in ``self.api_list`` and returns
        its result (``None`` otherwise); the backward hook does nothing.
    """

    def wrap_forward_hook(cell, input_data, output_data):
        result = None
        if need_wrapper_func():
            trimmed_name = api_name_with_id[:-1]
            first_sep = trimmed_name.find(Const.SEP)
            prefix = trimmed_name[:first_sep + 1]
            middle = trimmed_name[first_sep + 1:trimmed_name.rfind(Const.SEP)]
            api_name = MsConst.HOOK_MS_PREFIX_DICT.get(prefix, "") + middle
            if api_name in self.api_list:
                result = check_self(trimmed_name, output_data, self.ori_func.get(api_name),
                                    *input_data, **cell.input_kwargs)
        # The keyword arguments stored on the cell are removed on every
        # normally completed call, whether or not the check ran.
        del cell.input_kwargs
        return result

    def wrap_backward_hook(cell, grad_input, grad_output):
        # No backward-side processing is performed.
        return None

    return wrap_forward_hook, wrap_backward_hook
|
|
97
|
+
|
|
98
|
+
def store_original_func(self):
    """Fill ``self.ori_func`` with the second item returned by ``get_module``.

    One entry is stored per name in ``self.api_list``, keyed by that name.
    """
    for name in self.api_list:
        module_and_func = get_module(name)
        self.ori_func[name] = module_and_func[1]
|
|
40
101
|
|
|
41
102
|
|
|
42
103
|
def get_supported_ops():
|
|
43
104
|
supported_ops = []
|
|
44
105
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
45
|
-
yaml_path = os.path.join(cur_path, "data",
|
|
106
|
+
yaml_path = os.path.join(cur_path, "data", FreeBenchmarkConst.SUPPORTED_CHECK_API_FILE)
|
|
46
107
|
|
|
47
|
-
|
|
108
|
+
supported_ops_list = load_yaml(yaml_path)
|
|
48
109
|
for k, v in FreeBenchmarkConst.API_PREFIX_DICT.items():
|
|
49
|
-
ops =
|
|
110
|
+
ops = supported_ops_list.get(k)
|
|
50
111
|
if ops:
|
|
51
112
|
ops = [v + i for i in ops]
|
|
52
113
|
supported_ops += ops
|
|
@@ -57,7 +118,7 @@ def get_supported_ops():
|
|
|
57
118
|
_all_functional_ops += ms_ops
|
|
58
119
|
|
|
59
120
|
ms_tensor = dir(ms.Tensor)
|
|
60
|
-
ms_tensor = [MsConst.
|
|
121
|
+
ms_tensor = [MsConst.TENSOR_PREFIX + i for i in ms_tensor]
|
|
61
122
|
_all_functional_ops += ms_tensor
|
|
62
123
|
|
|
63
124
|
ms_mint = dir(ms.mint)
|
|
@@ -68,29 +129,9 @@ def get_supported_ops():
|
|
|
68
129
|
ms_mint_nn_func = [MsConst.MINT_NN_FUNC_PREFIX + i for i in ms_mint_nn_func]
|
|
69
130
|
_all_functional_ops += ms_mint_nn_func
|
|
70
131
|
|
|
71
|
-
ms_communication = dir(comm_func)
|
|
72
|
-
ms_communication = [MsConst.COMM_PREFIX + i for i in ms_communication]
|
|
73
|
-
_all_functional_ops += ms_communication
|
|
74
|
-
|
|
75
132
|
return set(supported_ops) & set(_all_functional_ops)
|
|
76
133
|
|
|
77
134
|
|
|
78
|
-
def get_decorate_func():
|
|
79
|
-
return decorate_forward_function
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
def is_func_support_decorate(orig_func):
|
|
83
|
-
return not inspect.isclass(orig_func) and callable(orig_func)
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
def get_wrapper_obj(orig_func, api_name):
|
|
87
|
-
if is_func_support_decorate(orig_func):
|
|
88
|
-
wrapped_obj = get_decorate_func()(orig_func, api_name)
|
|
89
|
-
else:
|
|
90
|
-
wrapped_obj = orig_func
|
|
91
|
-
return wrapped_obj
|
|
92
|
-
|
|
93
|
-
|
|
94
135
|
def get_module(api_name):
|
|
95
136
|
func_name_list = api_name.split(Const.SEP)
|
|
96
137
|
func_name = func_name_list[-1]
|
|
@@ -104,13 +145,93 @@ def get_module(api_name):
|
|
|
104
145
|
return module_obj, orig_func
|
|
105
146
|
|
|
106
147
|
|
|
107
|
-
def
|
|
108
|
-
|
|
109
|
-
|
|
148
|
+
def check_self(api_name_with_id, output, ori_func, *args, **kwargs):
    """Perturb one API call and hand both results to the configured handler.

    Args:
        api_name_with_id: Name used for logging and for creating the
            perturbation and the handler.
        output: Output of the original call; used as the original result
            unless ``Config.stage`` is ``Const.BACKWARD``.
        ori_func: Original function, stored on the handler parameters.
        *args: Positional arguments of the original call.
        **kwargs: Keyword arguments of the original call.

    Returns:
        The value returned by ``deal_fuzzed_and_original_result``, or ``None``
        when the check is skipped, the perturbation yields ``False``, or an
        exception is raised while checking (the exception is logged).
    """
    ret = None

    if Config.stage == Const.BACKWARD and not (check_all_tensor(args) and check_all_tensor(output)):
        logger.warning(f"{api_name_with_id} has non-tensor input or output.")
        return ret

    params = data_pre_deal(api_name_with_id, ori_func, *args, **kwargs)
    if params.index == -1:
        return ret

    logger.info(f"[{api_name_with_id}] is {Config.handler_type}ing.")
    api_register.api_set_ori_func()

    try:
        perturbation = PerturbationFactory.create(api_name_with_id)
        params.fuzzed_result = perturbation.handle(params)
        if params.fuzzed_result is False:
            return ret
        if Config.stage == Const.BACKWARD:
            params.original_result = Tools.get_grad(params.original_func, *params.args, **params.kwargs)
        else:
            params.original_result = output
        ret = deal_fuzzed_and_original_result(api_name_with_id, params)
    except Exception as e:
        logger.error(f"[{api_name_with_id}] Error: {str(e)}")
        logger.error(f"[{api_name_with_id}] Error detail: {traceback.format_exc()}")
    finally:
        # Undo api_set_ori_func() on every exit path, including exceptions
        # that are not subclasses of Exception.
        api_register.api_set_hook_func()

    return ret
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def check_all_tensor(input_output):
    """Return True when ``input_output`` consists solely of ``ms.Tensor`` values.

    A single ``ms.Tensor`` qualifies. A tuple or list qualifies when every
    element qualifies, checked recursively; an empty tuple or list therefore
    qualifies too. Any other value yields False.
    """
    if isinstance(input_output, ms.Tensor):
        return True
    if isinstance(input_output, (tuple, list)):
        # Generator form stops at the first non-tensor element instead of
        # recursing through the rest of the container.
        return all(check_all_tensor(v) for v in input_output)
    return False
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def get_target_arg_index(args) -> int:
    """Type check: locate the first positional argument of a supported type.

    An argument is accepted when it is a floating-point tensor, or when it is
    a list, tuple or dict (the container's contents are not inspected here).
    Tensors that are not floating point are skipped.

    Returns:
        The index of the first accepted argument, or -1 if there is none.
    """
    for position, candidate in enumerate(args):
        if ms.ops.is_tensor(candidate):
            if ms.ops.is_floating_point(candidate):
                return position
            # Non-floating-point tensor: keep looking at later arguments.
            continue
        if isinstance(candidate, (list, tuple, dict)):
            return position
    return -1
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def data_pre_deal(api_name_with_id, func, *args, **kwargs):
    """Bundle an API call into a ``HandlerParams`` instance.

    Stores the positional and keyword arguments, the original function, and
    the index reported by ``get_target_arg_index`` for the positional
    arguments. When that index is -1 a warning is logged; the params object
    is still returned so the caller can inspect ``index``.
    """
    target_index = get_target_arg_index(args)
    if target_index == -1:
        logger.warning(f"{api_name_with_id} has no supported input type.")

    handler_params = HandlerParams()
    handler_params.original_func = func
    handler_params.args = args
    handler_params.kwargs = kwargs
    handler_params.index = target_index
    return handler_params
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def need_wrapper_func():
    """Decide whether the current call should be wrapped.

    Returns False when the runtime is not running or the feature is not
    enabled, when ``Config.steps`` is non-empty and does not contain the
    current step, or when ``Config.ranks`` is non-empty, the rank is known
    (not -1) and it is not listed. Returns True otherwise.

    Side effect: while ``Runtime.rank_id`` is -1 it is refreshed from
    ``get_rank_if_initialized()``; it stays -1 if that raises
    ``DistributedNotInitializedError``.
    """
    if not Runtime.is_running or not Config.is_enable:
        return False

    step_excluded = Config.steps and Runtime.step_count not in Config.steps
    if step_excluded:
        return False

    if Runtime.rank_id == -1:
        try:
            Runtime.rank_id = get_rank_if_initialized()
        except DistributedNotInitializedError:
            Runtime.rank_id = -1

    # An unknown rank (-1) is never filtered out by the rank list.
    rank_excluded = Config.ranks and Runtime.rank_id != -1 and Runtime.rank_id not in Config.ranks
    return not rank_excluded
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def deal_fuzzed_and_original_result(api_name_with_id, params: HandlerParams):
    """Create the handler for this API via ``HandlerFactory`` and return its result for ``params``."""
    return HandlerFactory.create(api_name_with_id).handle(params)
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
2
17
|
|
|
3
18
|
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Optional, Any, Tuple, Dict, Callable
|
|
2
17
|
|
|
3
18
|
|
|
@@ -12,6 +27,5 @@ class HandlerParams:
|
|
|
12
27
|
original_result: Optional[Any] = None
|
|
13
28
|
fuzzed_result: Optional[Any] = None
|
|
14
29
|
is_consistent: Optional[bool] = True
|
|
15
|
-
save_flag: Optional[bool] = True
|
|
16
30
|
fuzzed_value: Optional[Any] = None
|
|
17
31
|
original_func: Optional[Callable] = None
|
|
@@ -1,14 +1,28 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
3
16
|
from dataclasses import dataclass
|
|
17
|
+
from typing import Any, Optional
|
|
4
18
|
|
|
5
19
|
import mindspore as ms
|
|
6
|
-
from mindspore import Tensor
|
|
20
|
+
from mindspore import Tensor, ops
|
|
7
21
|
|
|
8
|
-
from msprobe.mindspore.runtime import Runtime
|
|
9
22
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
10
|
-
from .config import Config
|
|
11
|
-
from .handler_params import HandlerParams
|
|
23
|
+
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
24
|
+
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
25
|
+
from msprobe.mindspore.runtime import Runtime
|
|
12
26
|
|
|
13
27
|
|
|
14
28
|
class Tools:
|
|
@@ -29,6 +43,23 @@ class Tools:
|
|
|
29
43
|
return FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD
|
|
30
44
|
return FreeBenchmarkConst.ERROR_THRESHOLD.get(dtype, FreeBenchmarkConst.ERROR_THRESHOLD.get(ms.float32))
|
|
31
45
|
|
|
46
|
+
@staticmethod
|
|
47
|
+
def get_grad_out(outputs):
    """Build a ones-valued counterpart of ``outputs``.

    A ``Tensor`` is replaced by ``ops.ones_like`` of itself. A tuple or list
    is rebuilt as the same container type with each element converted
    recursively. Any other value is returned unchanged.
    """
    if isinstance(outputs, Tensor):
        return ops.ones_like(outputs)
    if not isinstance(outputs, (tuple, list)):
        return outputs
    seeds = [Tools.get_grad_out(item) for item in outputs]
    return type(outputs)(seeds)
|
|
53
|
+
|
|
54
|
+
@staticmethod
|
|
55
|
+
def get_grad(func, *args, **kwargs):
    """Run ``func`` through ``ms.vjp`` and apply the returned vjp function.

    ``kwargs`` are bound inside a closure so that only the positional
    ``args`` are passed to ``ms.vjp``. The vjp function is called with the
    ones-valued structure that ``Tools.get_grad_out`` builds from the
    forward outputs.

    Returns:
        The result of the vjp function. Presumably the gradients with respect
        to ``args`` - confirm against the MindSpore ``vjp`` documentation.
    """
    def call_with_kwargs(*positional):
        return func(*positional, **kwargs)

    forward_out, backward_fn = ms.vjp(call_with_kwargs, *args)
    return backward_fn(Tools.get_grad_out(forward_out))
|
|
62
|
+
|
|
32
63
|
|
|
33
64
|
@dataclass
|
|
34
65
|
class UnequalRow:
|
|
@@ -59,10 +90,8 @@ def make_unequal_row(
|
|
|
59
90
|
if isinstance(ratio, float):
|
|
60
91
|
row.max_rel = ratio - 1
|
|
61
92
|
original_tensor = params.original_result
|
|
62
|
-
fuzzed_tensor = params.fuzzed_result
|
|
63
93
|
if index is not None:
|
|
64
94
|
original_tensor = original_tensor[index]
|
|
65
|
-
fuzzed_tensor = fuzzed_tensor[index]
|
|
66
95
|
row.output_index = index
|
|
67
96
|
if isinstance(original_tensor, Tensor):
|
|
68
97
|
row.dtype = original_tensor.dtype
|