mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -20,7 +20,7 @@ from mindspore import Tensor
|
|
|
20
20
|
from mindspore._c_expression import PyNativeExecutor_
|
|
21
21
|
from mindspore.common.api import _MindsporeFunctionExecutor
|
|
22
22
|
|
|
23
|
-
from msprobe.
|
|
23
|
+
from msprobe.core.common.log import logger
|
|
24
24
|
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
|
|
25
25
|
from msprobe.core.common.const import Const
|
|
26
26
|
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs
|
|
@@ -33,6 +33,8 @@ def dump_jit(name, in_feat, out_feat, is_forward):
|
|
|
33
33
|
index = ori_args.find("<")
|
|
34
34
|
if index != 0 and index != -1:
|
|
35
35
|
result = ori_args[0:index]
|
|
36
|
+
elif name is not None and "<" not in str(name):
|
|
37
|
+
result = str(name)
|
|
36
38
|
else:
|
|
37
39
|
result = "JitFunction"
|
|
38
40
|
if JitDump.need_dump():
|
|
@@ -47,7 +49,7 @@ def dump_jit(name, in_feat, out_feat, is_forward):
|
|
|
47
49
|
name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \
|
|
48
50
|
Const.BACKWARD
|
|
49
51
|
JitDump.data_collector.update_api_or_module_name(name_template)
|
|
50
|
-
module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat
|
|
52
|
+
module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat, grad_output=out_feat)
|
|
51
53
|
JitDump.data_collector.backward_data_collect(name_template, None, pid, module_input_output)
|
|
52
54
|
|
|
53
55
|
|
|
@@ -59,15 +61,25 @@ class JitDump(_MindsporeFunctionExecutor):
|
|
|
59
61
|
|
|
60
62
|
def __init__(self, *args, **kwargs):
    """Initialise the executor and remember the name of the wrapped callable.

    All positional and keyword arguments are forwarded unchanged to the
    parent executor. ``self.name`` is set to the ``__name__`` of the first
    positional argument, or to ``None`` when no positional argument was
    given or that object has no ``__name__``.
    """
    super().__init__(*args, **kwargs)
    # Callables such as functools.partial objects or callable instances do
    # not define ``__name__``; fall back to None instead of raising
    # AttributeError so that __call__ can use its own fallback naming.
    self.name = getattr(args[0], "__name__", None) if args else None
    self._executor = PyNativeExecutor_.get_instance()
|
|
63
68
|
|
|
64
69
|
def __call__(self, *args, **kwargs):
    """Run the wrapped jit function and dump its forward data when enabled.

    While ``JitDump.jit_dump_switch`` is on, ``api_register.api_set_ori_func()``
    is called before the wrapped function runs and
    ``api_register.api_set_hook_func()`` afterwards (presumably swapping the
    hooked APIs out and back in -- confirm against api_registry).

    Returns:
        Whatever the parent executor's ``__call__`` returns.
    """
    if JitDump.jit_dump_switch:
        api_register.api_set_ori_func()
    try:
        out = super().__call__(*args, **kwargs)
        if JitDump.jit_dump_switch and len(args) > 0:
            # "construct" is not a useful dump name, so fall back to the
            # first call argument in that case (and when no name is known).
            if self.name and self.name != "construct":
                dump_jit(self.name, args, out, True)
            else:
                dump_jit(args[0], args, out, True)
            JitDump.jit_enable = True
        elif len(args) == 0:
            logger.warning(f"The jit function {self.name} has no input arguments, nothing will be dumped.")
    finally:
        # Re-register the hooks even if the wrapped function or the dump
        # raised; otherwise the API registry would stay in the "ori" state.
        if JitDump.jit_dump_switch:
            api_register.api_set_hook_func()
    return out
|
|
72
84
|
|
|
73
85
|
@classmethod
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.file_utils import save_json
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def create_kernel_config_json(dump_path, cur_rank):
    """Write a kernel dump configuration JSON under ``dump_path``.

    The file is named ``kernel_config.json`` when ``cur_rank`` is the empty
    string, otherwise ``kernel_config_{cur_rank}.json``.

    Args:
        dump_path: Directory that receives the file; also stored in the
            configuration as ``"dump_path"``.
        cur_rank: Rank identifier used in the file name, or ``''``.

    Returns:
        The path of the configuration file that was written.
    """
    if cur_rank == '':
        file_name = "kernel_config.json"
    else:
        file_name = f"kernel_config_{cur_rank}.json"
    target = os.path.join(dump_path, file_name)

    dump_settings = {
        "dump_list": [],
        "dump_path": dump_path,
        "dump_mode": "all",
        "dump_op_switch": "on",
    }
    save_json(target, {"dump": dump_settings}, indent=4)
    return target
|
|
@@ -13,10 +13,9 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import json
|
|
17
16
|
import os
|
|
18
17
|
|
|
19
|
-
from msprobe.core.common.file_utils import
|
|
18
|
+
from msprobe.core.common.file_utils import create_directory, save_json
|
|
20
19
|
from msprobe.mindspore.common.log import logger
|
|
21
20
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
22
21
|
|
|
@@ -57,13 +56,19 @@ class KernelGraphDump:
|
|
|
57
56
|
self.dump_json["common_dump_settings"]["input_output"] = 2
|
|
58
57
|
|
|
59
58
|
def handle(self):
|
|
59
|
+
try:
|
|
60
|
+
from msprobe.lib import _msprobe_c
|
|
61
|
+
return
|
|
62
|
+
except ImportError:
|
|
63
|
+
# If _msprobe_c is not installed, fall back to the legacy MindSpore flow.
|
|
64
|
+
logger.info("Module _msprobe_c has not been installed, use interface in mindspore instead.")
|
|
65
|
+
|
|
60
66
|
if os.getenv("GRAPH_OP_RUN") == "1":
|
|
61
67
|
raise Exception("Must run in graph mode, not kbk mode")
|
|
62
68
|
json_path = self.dump_json["common_dump_settings"]["path"]
|
|
63
69
|
create_directory(json_path)
|
|
64
70
|
json_path = os.path.join(json_path, "kernel_graph_dump.json")
|
|
65
|
-
|
|
66
|
-
json.dump(self.dump_json, f)
|
|
71
|
+
save_json(json_path, self.dump_json, indent=4)
|
|
67
72
|
logger.info(json_path + " has been created.")
|
|
68
73
|
os.environ["MINDSPORE_DUMP_CONFIG"] = json_path
|
|
69
74
|
if self.dump_json["common_dump_settings"]["dump_mode"] == 0:
|
|
@@ -13,11 +13,10 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import json
|
|
17
16
|
import os
|
|
18
17
|
|
|
19
18
|
from msprobe.core.common.const import Const
|
|
20
|
-
from msprobe.core.common.file_utils import
|
|
19
|
+
from msprobe.core.common.file_utils import create_directory, save_json
|
|
21
20
|
from msprobe.mindspore.common.log import logger
|
|
22
21
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
23
22
|
|
|
@@ -70,8 +69,7 @@ class KernelKbykDump:
|
|
|
70
69
|
json_path = self.dump_json[KernelKbykDump.COMMON_SETTINGS]["path"]
|
|
71
70
|
create_directory(json_path)
|
|
72
71
|
json_path = os.path.join(json_path, "kernel_kbyk_dump.json")
|
|
73
|
-
|
|
74
|
-
json.dump(self.dump_json, f)
|
|
72
|
+
save_json(json_path, self.dump_json, indent=4)
|
|
75
73
|
logger.info(json_path + " has been created.")
|
|
76
74
|
|
|
77
75
|
os.environ["MINDSPORE_DUMP_CONFIG"] = json_path
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright 2024 Huawei Technologies Co., Ltd
|
|
3
|
+
*
|
|
4
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
* you may not use this file except in compliance with the License.
|
|
6
|
+
* You may obtain a copy of the License at
|
|
7
|
+
*
|
|
8
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
*
|
|
10
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
* See the License for the specific language governing permissions and
|
|
14
|
+
* limitations under the License.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
#include "hook_dynamic_loader.h"
|
|
18
|
+
#include <sys/stat.h>
|
|
19
|
+
#include <cstdlib>
|
|
20
|
+
#include <cstring>
|
|
21
|
+
#include "utils/log_adapter.h"
|
|
22
|
+
|
|
23
|
+
namespace {
|
|
24
|
+
|
|
25
|
+
// Utility function to check if a file path is valid
|
|
26
|
+
bool IsValidPath(const std::string &path) {
|
|
27
|
+
struct stat fileStat;
|
|
28
|
+
if (stat(path.c_str(), &fileStat) != 0) {
|
|
29
|
+
MS_LOG(ERROR) << "File does not exist or cannot be accessed: " << path;
|
|
30
|
+
return false;
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
if (S_ISLNK(fileStat.st_mode)) {
|
|
34
|
+
MS_LOG(ERROR) << "File is a symbolic link, which is not allowed: " << path;
|
|
35
|
+
return false;
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
if (!S_ISREG(fileStat.st_mode)) {
|
|
39
|
+
MS_LOG(ERROR) << "File is not a regular file: " << path;
|
|
40
|
+
return false;
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
if (path.substr(path.find_last_of(".")) != ".so") {
|
|
44
|
+
MS_LOG(ERROR) << "File is not a .so file: " << path;
|
|
45
|
+
return false;
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
return true;
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
} // namespace
|
|
52
|
+
|
|
53
|
+
// Returns the single shared HookDynamicLoader.
HookDynamicLoader &HookDynamicLoader::GetInstance() {
  // Function-local static: constructed on the first call and reused afterwards.
  static HookDynamicLoader loader;
  return loader;
}
|
|
57
|
+
|
|
58
|
+
bool HookDynamicLoader::loadFunction(void *handle, const std::string &functionName) {
|
|
59
|
+
void *func = dlsym(handle, functionName.c_str());
|
|
60
|
+
if (!func) {
|
|
61
|
+
MS_LOG(WARNING) << "Could not load function: " << functionName << ", error: " << dlerror();
|
|
62
|
+
return false;
|
|
63
|
+
}
|
|
64
|
+
funcMap_[functionName] = func;
|
|
65
|
+
return true;
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
bool HookDynamicLoader::validateLibraryPath(const std::string &libPath) {
|
|
69
|
+
char *realPath = realpath(libPath.c_str(), nullptr);
|
|
70
|
+
if (!realPath) {
|
|
71
|
+
MS_LOG(WARNING) << "Failed to resolve realpath for the library: " << libPath;
|
|
72
|
+
return false;
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
bool isValid = IsValidPath(realPath);
|
|
76
|
+
free(realPath); // Free memory allocated by realpath
|
|
77
|
+
return isValid;
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
bool HookDynamicLoader::LoadLibrary() {
|
|
81
|
+
const char *libPath = std::getenv("HOOK_TOOL_PATH");
|
|
82
|
+
if (!libPath) {
|
|
83
|
+
MS_LOG(WARNING) << "HOOK_TOOL_PATH is not set!";
|
|
84
|
+
return false;
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
std::string resolvedLibPath(libPath);
|
|
88
|
+
if (!validateLibraryPath(resolvedLibPath)) {
|
|
89
|
+
MS_LOG(WARNING) << "Library path validation failed.";
|
|
90
|
+
return false;
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
std::lock_guard<std::mutex> lock(mutex_);
|
|
94
|
+
if (handle_) {
|
|
95
|
+
MS_LOG(WARNING) << "Hook library already loaded!";
|
|
96
|
+
return false;
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
handle_ = dlopen(resolvedLibPath.c_str(), RTLD_LAZY | RTLD_LOCAL);
|
|
100
|
+
if (!handle_) {
|
|
101
|
+
MS_LOG(WARNING) << "Failed to load Hook library: " << dlerror();
|
|
102
|
+
return false;
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
for (const auto &functionName : functionList_) {
|
|
106
|
+
if (!loadFunction(handle_, functionName)) {
|
|
107
|
+
MS_LOG(WARNING) << "Failed to load function: " << functionName;
|
|
108
|
+
dlclose(handle_);
|
|
109
|
+
handle_ = nullptr;
|
|
110
|
+
return false;
|
|
111
|
+
}
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
MS_LOG(INFO) << "Hook library loaded successfully.";
|
|
115
|
+
return true;
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
bool HookDynamicLoader::UnloadLibrary() {
|
|
119
|
+
std::lock_guard<std::mutex> lock(mutex_);
|
|
120
|
+
if (!handle_) {
|
|
121
|
+
MS_LOG(WARNING) << "Hook library hasn't been loaded.";
|
|
122
|
+
return false;
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
dlclose(handle_);
|
|
126
|
+
handle_ = nullptr;
|
|
127
|
+
funcMap_.clear();
|
|
128
|
+
MS_LOG(INFO) << "Library unloaded successfully.";
|
|
129
|
+
return true;
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
void *HookDynamicLoader::GetHooker(const std::string &funcName) {
|
|
133
|
+
std::lock_guard<std::mutex> lock(mutex_);
|
|
134
|
+
auto iter = funcMap_.find(funcName);
|
|
135
|
+
if (iter == funcMap_.end()) {
|
|
136
|
+
MS_LOG(WARNING) << "Function not found: " << funcName;
|
|
137
|
+
return nullptr;
|
|
138
|
+
}
|
|
139
|
+
return iter->second;
|
|
140
|
+
}
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright 2024 Huawei Technologies Co., Ltd
|
|
3
|
+
*
|
|
4
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
* you may not use this file except in compliance with the License.
|
|
6
|
+
* You may obtain a copy of the License at
|
|
7
|
+
*
|
|
8
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
*
|
|
10
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
* See the License for the specific language governing permissions and
|
|
14
|
+
* limitations under the License.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
#ifndef HOOK_DYNAMIC_LOADER_H
|
|
18
|
+
#define HOOK_DYNAMIC_LOADER_H
|
|
19
|
+
|
|
20
|
+
#include <dlfcn.h>
|
|
21
|
+
#include <string>
|
|
22
|
+
#include <vector>
|
|
23
|
+
#include <map>
|
|
24
|
+
#include <mutex>
|
|
25
|
+
|
|
26
|
+
// Names of the symbols that are looked up in the hook library.
constexpr const char *kHookBegin = "MS_DbgOnStepBegin";
constexpr const char *kHookEnd = "MS_DbgOnStepEnd";

// Singleton that loads a hook shared library at run time and hands out the
// addresses of the symbols listed in functionList_.
class HookDynamicLoader {
 public:
  // Returns the single shared instance.
  static HookDynamicLoader &GetInstance();

  // Non-copyable: there is exactly one loader.
  HookDynamicLoader(const HookDynamicLoader &) = delete;
  HookDynamicLoader &operator=(const HookDynamicLoader &) = delete;

  // Loads the library and resolves the symbols; returns true on success.
  bool LoadLibrary();
  // Closes the library; returns false when nothing was loaded.
  bool UnloadLibrary();
  // Returns the address recorded for `funcName`, or nullptr when unknown.
  void *GetHooker(const std::string &funcName);

 private:
  // Helper functions
  bool loadFunction(void *handle, const std::string &functionName);
  bool validateLibraryPath(const std::string &libPath);

  // Private so that instances are only created through GetInstance().
  HookDynamicLoader() = default;

  void *handle_ = nullptr;                                          // dlopen() handle, nullptr when not loaded
  std::vector<std::string> functionList_ = {kHookBegin, kHookEnd};  // symbols to resolve
  std::map<std::string, void *> funcMap_;                           // symbol name -> resolved address
  std::mutex mutex_;                                                // guards handle_ and funcMap_
};
|
|
52
|
+
|
|
53
|
+
#endif // HOOK_DYNAMIC_LOADER_H
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
5
|
# you may not use this file except in compliance with the License.
|
|
6
6
|
# You may obtain a copy of the License at
|
|
7
7
|
#
|
|
@@ -13,24 +13,31 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import functools
|
|
16
17
|
import importlib
|
|
17
|
-
import inspect
|
|
18
18
|
import os
|
|
19
|
+
import traceback
|
|
19
20
|
|
|
20
21
|
import mindspore as ms
|
|
21
|
-
from mindspore.communication import comm_func
|
|
22
|
-
|
|
23
22
|
from msprobe.core.common.const import Const
|
|
23
|
+
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
24
24
|
from msprobe.core.common.file_utils import check_path_length, load_yaml
|
|
25
25
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
26
26
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
27
27
|
from msprobe.mindspore.common.log import logger
|
|
28
|
+
from msprobe.mindspore.common.utils import get_rank_if_initialized
|
|
28
29
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
30
|
+
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
31
|
+
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
29
32
|
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
30
|
-
from msprobe.mindspore.free_benchmark.
|
|
33
|
+
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
34
|
+
from msprobe.mindspore.free_benchmark.common.utils import Tools
|
|
35
|
+
from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory
|
|
36
|
+
from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory
|
|
37
|
+
from msprobe.mindspore.runtime import Runtime
|
|
31
38
|
|
|
32
39
|
|
|
33
|
-
class
|
|
40
|
+
class ApiPyNativeSelfCheck:
|
|
34
41
|
def __init__(self, config: DebuggerConfig):
|
|
35
42
|
Config.is_enable = True
|
|
36
43
|
Config.handler_type = config.handler_type
|
|
@@ -39,29 +46,77 @@ class ApiPyNativeSelFCheck:
|
|
|
39
46
|
Config.dump_level = config.dump_level
|
|
40
47
|
Config.steps = config.step
|
|
41
48
|
Config.ranks = config.rank
|
|
42
|
-
Config.dump_path = os.path.join(config.dump_path,
|
|
49
|
+
Config.dump_path = os.path.join(config.dump_path, FreeBenchmarkConst.CHECK_RESULT_FILE)
|
|
43
50
|
check_path_length(Config.dump_path)
|
|
44
51
|
|
|
52
|
+
self.ori_func = {}
|
|
53
|
+
|
|
45
54
|
self.api_list = config.list
|
|
46
55
|
all_api = get_supported_ops()
|
|
47
56
|
if not self.api_list:
|
|
48
57
|
self.api_list = all_api
|
|
49
58
|
else:
|
|
50
59
|
self.api_list = set(self.api_list) & all_api
|
|
60
|
+
self.store_original_func()
|
|
51
61
|
|
|
52
62
|
def handle(self):
    """Register this checker's hook builder and switch the APIs to hooks.

    Passes ``self.build_hook`` to ``api_register.initialize_hook`` and then
    calls ``api_register.api_set_hook_func()``.
    """
    hook_builder = self.build_hook
    api_register.initialize_hook(hook_builder)
    api_register.api_set_hook_func()
|
|
65
|
+
|
|
66
|
+
def build_hook(self, api_name):
    """Create the four hook callables for one API.

    Args:
        api_name: Key used for the ``HOOKCell`` call counter and as the
            leading part of the per-call identifier.

    Returns:
        The tuple ``(pre_hook, wrap_forward_hook, wrap_backward_hook,
        pre_backward_hook)``. Only the forward hook does any work; the
        other three always return ``None``.
    """

    def pre_hook(cell, input_data):
        return None

    def forward_hook(api_name_with_id, cell, input_data, output_data):
        outcome = None

        if need_wrapper_func():
            # Drop the last character: the identifier is built below as
            # name + count + Const.SEP (assumes Const.SEP is one character).
            call_id = api_name_with_id[:-1]
            first_sep = call_id.find(Const.SEP)
            last_sep = call_id.rfind(Const.SEP)
            # Map the hook prefix (text up to and including the first
            # separator) to its MindSpore prefix, then append the middle
            # part of the identifier to get the full API name.
            ms_prefix = MsConst.HOOK_MS_PREFIX_DICT.get(call_id[:first_sep + 1], "")
            full_api_name = ms_prefix + call_id[first_sep + 1:last_sep]
            if full_api_name in self.api_list:
                outcome = check_self(call_id, output_data, self.ori_func.get(full_api_name),
                                     *input_data, **cell.input_kwargs)

        # The stored keyword arguments are removed on every path.
        del cell.input_kwargs
        return outcome

    def backward_hook(cell, grad_input, grad_output):
        return None

    # NOTE(review): the result of this first call is unused in the original;
    # the call is kept unchanged in case it has a side effect.
    HOOKCell.get_cell_count(api_name)
    call_id_with_sep = api_name + str(HOOKCell.get_cell_count(api_name)) + Const.SEP
    bound_forward_hook = functools.partial(forward_hook, call_id_with_sep)
    HOOKCell.add_cell_count(api_name)

    def wrap_forward_hook(cell, input_data, output_data):
        return bound_forward_hook(cell, input_data, output_data)

    def wrap_backward_hook(cell, grad_input, grad_output):
        return backward_hook(cell, grad_input, grad_output)

    def pre_backward_hook(cell, grad_input):
        return None

    return pre_hook, wrap_forward_hook, wrap_backward_hook, pre_backward_hook
|
|
106
|
+
|
|
107
|
+
def store_original_func(self):
    """Record the original callable of every API in ``self.api_list``.

    ``get_module`` returns ``(module_obj, orig_func)``; only the callable is
    stored, keyed by API name, in ``self.ori_func``.
    """
    for name in self.api_list:
        _, original = get_module(name)
        self.ori_func[name] = original
|
|
55
110
|
|
|
56
111
|
|
|
57
112
|
def get_supported_ops():
|
|
58
113
|
supported_ops = []
|
|
59
114
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
60
|
-
yaml_path = os.path.join(cur_path, "data",
|
|
115
|
+
yaml_path = os.path.join(cur_path, "data", FreeBenchmarkConst.SUPPORTED_CHECK_API_FILE)
|
|
61
116
|
|
|
62
|
-
|
|
117
|
+
supported_ops_list = load_yaml(yaml_path)
|
|
63
118
|
for k, v in FreeBenchmarkConst.API_PREFIX_DICT.items():
|
|
64
|
-
ops =
|
|
119
|
+
ops = supported_ops_list.get(k)
|
|
65
120
|
if ops:
|
|
66
121
|
ops = [v + i for i in ops]
|
|
67
122
|
supported_ops += ops
|
|
@@ -72,7 +127,7 @@ def get_supported_ops():
|
|
|
72
127
|
_all_functional_ops += ms_ops
|
|
73
128
|
|
|
74
129
|
ms_tensor = dir(ms.Tensor)
|
|
75
|
-
ms_tensor = [MsConst.
|
|
130
|
+
ms_tensor = [MsConst.TENSOR_PREFIX + i for i in ms_tensor]
|
|
76
131
|
_all_functional_ops += ms_tensor
|
|
77
132
|
|
|
78
133
|
ms_mint = dir(ms.mint)
|
|
@@ -83,49 +138,109 @@ def get_supported_ops():
|
|
|
83
138
|
ms_mint_nn_func = [MsConst.MINT_NN_FUNC_PREFIX + i for i in ms_mint_nn_func]
|
|
84
139
|
_all_functional_ops += ms_mint_nn_func
|
|
85
140
|
|
|
86
|
-
ms_communication = dir(comm_func)
|
|
87
|
-
ms_communication = [MsConst.COMM_PREFIX + i for i in ms_communication]
|
|
88
|
-
_all_functional_ops += ms_communication
|
|
89
|
-
|
|
90
141
|
return set(supported_ops) & set(_all_functional_ops)
|
|
91
142
|
|
|
92
143
|
|
|
93
|
-
def get_decorate_func():
|
|
94
|
-
return decorate_forward_function
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
def is_func_support_decorate(orig_func):
|
|
98
|
-
return not inspect.isclass(orig_func) and callable(orig_func)
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
def get_wrapper_obj(orig_func, api_name):
|
|
102
|
-
if is_func_support_decorate(orig_func):
|
|
103
|
-
wrapped_obj = get_decorate_func()(orig_func, api_name)
|
|
104
|
-
else:
|
|
105
|
-
wrapped_obj = orig_func
|
|
106
|
-
return wrapped_obj
|
|
107
|
-
|
|
108
|
-
|
|
109
144
|
def get_module(api_name):
    """Resolve a separator-joined API path to its owner object and attribute.

    Args:
        api_name: Path whose components are joined by ``Const.SEP``. The
            first component is imported as a module and the last one is the
            attribute to fetch.

    Returns:
        ``(owner, attribute)``: the object reached by walking the middle
        components, and the attribute named by the last component.
    """
    parts = api_name.split(Const.SEP)
    owner = importlib.import_module(parts[0])
    # ``depth`` is the number of leading components that make up the path of
    # the attribute currently being resolved.
    for depth, attr in enumerate(parts[1:-1], start=2):
        if not hasattr(owner, attr):
            # Not yet visible as an attribute of its parent: import it by
            # its full path before reading it.
            importlib.import_module(f"{Const.SEP.join(parts[:depth])}")
        owner = getattr(owner, attr)

    return owner, getattr(owner, parts[-1])
|
|
120
155
|
|
|
121
156
|
|
|
122
|
-
def
|
|
123
|
-
|
|
124
|
-
|
|
157
|
+
def check_self(api_name_with_id, output, ori_func, *args, **kwargs):
    """Run the perturbation self-check for one API call.

    Args:
        api_name_with_id: Identifier of the call, used for logging and to
            create the perturbation and the handler.
        output: Output of the original call; used as the original result
            unless ``Config.stage`` is ``Const.BACKWARD``.
        ori_func: Original callable of the API.
        *args: Positional arguments of the original call.
        **kwargs: Keyword arguments of the original call.

    Returns:
        The value returned by ``deal_fuzzed_and_original_result``, or
        ``None`` when the call is skipped, the perturbation result is
        ``False``, or an exception was caught during the check.
    """
    if Config.stage == Const.BACKWARD and not (check_all_tensor(args) and check_all_tensor(output)):
        logger.warning(f"{api_name_with_id} has non-tensor input or output.")
        return None

    params = data_pre_deal(api_name_with_id, ori_func, *args, **kwargs)
    if params.index == -1:
        return None

    logger.info(f"[{api_name_with_id}] is {Config.handler_type}ing.")
    api_register.api_set_ori_func()

    try:
        perturbation = PerturbationFactory.create(api_name_with_id)
        params.fuzzed_result = perturbation.handle(params)
        if params.fuzzed_result is False:
            return None
        if Config.stage == Const.BACKWARD:
            params.original_result = Tools.get_grad(params.original_func, *params.args, **params.kwargs)
        else:
            params.original_result = output
        return deal_fuzzed_and_original_result(api_name_with_id, params)
    except Exception as e:
        # Deliberate best-effort boundary: a failing check is logged and
        # reported as "no result" instead of propagating to the caller.
        logger.error(f"[{api_name_with_id}] Error: {str(e)}")
        logger.error(f"[{api_name_with_id}] Error detail: {traceback.format_exc()}")
        return None
    finally:
        # Undo api_set_ori_func() on every exit path, including exits that
        # ``except Exception`` does not intercept (e.g. KeyboardInterrupt).
        api_register.api_set_hook_func()
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def check_all_tensor(input_output):
    """Tell whether a value consists only of ``ms.Tensor`` objects.

    A tensor counts as valid, and so does a tuple or list (nested to any
    depth) whose elements are all valid. An empty tuple or list is valid.
    Any other value is not.
    """
    if isinstance(input_output, ms.Tensor):
        return True
    if not isinstance(input_output, (tuple, list)):
        return False
    return all(check_all_tensor(item) for item in input_output)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def get_target_arg_index(args) -> int:
    """Type check: locate the first positional argument of a supported type.

    An argument qualifies when it is a floating-point tensor, or when it is
    a list, tuple or dict. Tensors that are not floating point are skipped.

    Returns:
        The index of the first qualifying argument, or ``-1`` if none does.
    """
    for position, candidate in enumerate(args):
        if ms.ops.is_tensor(candidate):
            if ms.ops.is_floating_point(candidate):
                return position
        elif isinstance(candidate, (list, tuple, dict)):
            return position
    return -1
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def data_pre_deal(api_name_with_id, func, *args, **kwargs):
    """Bundle one call's function and arguments into a ``HandlerParams``.

    Args:
        api_name_with_id: Identifier of the call, used only in the warning.
        func: Stored as ``original_func``.
        *args: Stored as ``args`` and scanned for the target argument.
        **kwargs: Stored as ``kwargs``.

    Returns:
        The populated ``HandlerParams``. Its ``index`` is ``-1`` (after a
        warning is logged) when no positional argument has a supported type.
    """
    params = HandlerParams()
    params.args, params.kwargs, params.original_func = args, kwargs, func

    target_index = get_target_arg_index(args)
    if target_index == -1:
        logger.warning(f"{api_name_with_id} has no supported input type.")
    params.index = target_index
    return params
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def need_wrapper_func():
    """Decide whether the self-check should run for the current call.

    Returns ``False`` when the runtime is not running or the check is not
    enabled, when ``Config.steps`` is set and excludes the current step, or
    when ``Config.ranks`` is set and excludes a known rank. An unknown rank
    (``-1``) is never filtered out.
    """
    if not Runtime.is_running or not Config.is_enable:
        return False

    if Config.steps and Runtime.step_count not in Config.steps:
        return False

    if Runtime.rank_id == -1:
        # Rank not resolved yet: try again on every call until it is known.
        try:
            Runtime.rank_id = get_rank_if_initialized()
        except DistributedNotInitializedError:
            Runtime.rank_id = -1

    rank_excluded = Config.ranks and Runtime.rank_id != -1 and Runtime.rank_id not in Config.ranks
    return not rank_excluded
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def deal_fuzzed_and_original_result(api_name_with_id, params: HandlerParams):
    """Create the handler for this call and return what it yields for ``params``."""
    return HandlerFactory.create(api_name_with_id).handle(params)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
5
|
# you may not use this file except in compliance with the License.
|
|
6
6
|
# You may obtain a copy of the License at
|
|
7
7
|
#
|
|
@@ -27,6 +27,5 @@ class HandlerParams:
|
|
|
27
27
|
original_result: Optional[Any] = None
|
|
28
28
|
fuzzed_result: Optional[Any] = None
|
|
29
29
|
is_consistent: Optional[bool] = True
|
|
30
|
-
save_flag: Optional[bool] = True
|
|
31
30
|
fuzzed_value: Optional[Any] = None
|
|
32
31
|
original_func: Optional[Callable] = None
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
5
|
# you may not use this file except in compliance with the License.
|
|
6
6
|
# You may obtain a copy of the License at
|
|
7
7
|
#
|
|
@@ -17,7 +17,7 @@ from dataclasses import dataclass
|
|
|
17
17
|
from typing import Any, Optional
|
|
18
18
|
|
|
19
19
|
import mindspore as ms
|
|
20
|
-
from mindspore import Tensor
|
|
20
|
+
from mindspore import Tensor, ops
|
|
21
21
|
|
|
22
22
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
23
23
|
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
@@ -43,6 +43,23 @@ class Tools:
|
|
|
43
43
|
return FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD
|
|
44
44
|
return FreeBenchmarkConst.ERROR_THRESHOLD.get(dtype, FreeBenchmarkConst.ERROR_THRESHOLD.get(ms.float32))
|
|
45
45
|
|
|
46
|
+
@staticmethod
def get_grad_out(outputs):
    """Build a value shaped like ``outputs`` with tensors replaced by ones.

    A ``Tensor`` becomes ``ops.ones_like`` of itself. A tuple or list is
    rebuilt as the same container type with each element converted
    recursively. Any other value is returned unchanged.
    """
    if isinstance(outputs, Tensor):
        return ops.ones_like(outputs)
    if not isinstance(outputs, (tuple, list)):
        return outputs
    container = type(outputs)
    return container(list(map(Tools.get_grad_out, outputs)))
|
|
53
|
+
|
|
54
|
+
@staticmethod
def get_grad(func, *args, **kwargs):
    """Evaluate ``func`` through ``ms.vjp`` and return the result of its vjp function.

    ``kwargs`` are bound inside a wrapper so that only the positional
    ``args`` are passed to ``ms.vjp``. The vjp function is called with the
    value built by ``Tools.get_grad_out`` from the forward outputs.
    """
    def call_with_kwargs(*inputs):
        return func(*inputs, **kwargs)

    forward_out, vjp_fn = ms.vjp(call_with_kwargs, *args)
    return vjp_fn(Tools.get_grad_out(forward_out))
|
|
62
|
+
|
|
46
63
|
|
|
47
64
|
@dataclass
|
|
48
65
|
class UnequalRow:
|
|
@@ -73,10 +90,8 @@ def make_unequal_row(
|
|
|
73
90
|
if isinstance(ratio, float):
|
|
74
91
|
row.max_rel = ratio - 1
|
|
75
92
|
original_tensor = params.original_result
|
|
76
|
-
fuzzed_tensor = params.fuzzed_result
|
|
77
93
|
if index is not None:
|
|
78
94
|
original_tensor = original_tensor[index]
|
|
79
|
-
fuzzed_tensor = fuzzed_tensor[index]
|
|
80
95
|
row.output_index = index
|
|
81
96
|
if isinstance(original_tensor, Tensor):
|
|
82
97
|
row.dtype = original_tensor.dtype
|