mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -21,6 +21,7 @@ from mindspore.common.tensor import Tensor
|
|
|
21
21
|
from msprobe.core.common.utils import Const, DumpException
|
|
22
22
|
from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, ModuleBackwardOutputs,
|
|
23
23
|
ModuleForwardInputsOutputs)
|
|
24
|
+
from msprobe.core.hook_manager import BaseHookManager
|
|
24
25
|
from msprobe.mindspore.common.log import logger
|
|
25
26
|
|
|
26
27
|
|
|
@@ -58,7 +59,7 @@ class PrimitiveHookService:
|
|
|
58
59
|
def backward_hook(grad):
|
|
59
60
|
captured_grads.extend(grad)
|
|
60
61
|
backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}"
|
|
61
|
-
|
|
62
|
+
self.service_instance.inner_switch = True
|
|
62
63
|
try:
|
|
63
64
|
if hook_type == Const.INPUT:
|
|
64
65
|
self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
|
|
@@ -77,6 +78,7 @@ class PrimitiveHookService:
|
|
|
77
78
|
logger.error(f"This is a primitive op {hook_type}_backward dump error: {exception}, "
|
|
78
79
|
f"updated_primitive_name: {updated_primitive_name}")
|
|
79
80
|
raise DumpException(DumpException.BACKWARD_DATA_COLLECTION_ERROR) from exception
|
|
81
|
+
self.service_instance.inner_switch = False
|
|
80
82
|
|
|
81
83
|
return backward_hook
|
|
82
84
|
|
|
@@ -137,6 +139,7 @@ class PrimitiveHookService:
|
|
|
137
139
|
|
|
138
140
|
def pre_forward_hook(primitive_name, primitive_instance, args, kwargs):
|
|
139
141
|
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
|
|
142
|
+
self.service_instance.inner_switch = True
|
|
140
143
|
try:
|
|
141
144
|
self.service_instance.data_collector.forward_input_data_collect(
|
|
142
145
|
primitive_name,
|
|
@@ -148,9 +151,11 @@ class PrimitiveHookService:
|
|
|
148
151
|
logger.error(f"This is a primitive op dump error during forward input data collection: {exception}, "
|
|
149
152
|
f"primitive_name: {primitive_name}")
|
|
150
153
|
raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
|
|
154
|
+
self.service_instance.inner_switch = False
|
|
151
155
|
|
|
152
156
|
def post_forward_hook(primitive_name, primitive_instance, args, kwargs, output):
|
|
153
157
|
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
|
|
158
|
+
self.service_instance.inner_switch = True
|
|
154
159
|
try:
|
|
155
160
|
self.service_instance.data_collector.forward_output_data_collect(
|
|
156
161
|
primitive_name,
|
|
@@ -162,6 +167,7 @@ class PrimitiveHookService:
|
|
|
162
167
|
logger.error(f"This is a primitive op dump error during forward output data collection: {exception}, "
|
|
163
168
|
f"primitive_name: {primitive_name}")
|
|
164
169
|
raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
|
|
170
|
+
self.service_instance.inner_switch = False
|
|
165
171
|
|
|
166
172
|
def wrapped_primitive_call(instance_self, *args, **kwargs):
|
|
167
173
|
"""
|
|
@@ -179,7 +185,7 @@ class PrimitiveHookService:
|
|
|
179
185
|
current_count = self.primitive_counters.get(primitive_name, 0)
|
|
180
186
|
updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}{Const.SEP}{primitive_name}{Const.SEP}{current_count}"
|
|
181
187
|
|
|
182
|
-
if not self.service_instance.primitive_switch:
|
|
188
|
+
if not self.service_instance.primitive_switch or BaseHookManager.inner_switch:
|
|
183
189
|
return origin_func(*args, **kwargs)
|
|
184
190
|
|
|
185
191
|
captured_grads_input, captured_grads_output = [], []
|
|
@@ -564,15 +564,15 @@ tensor:
|
|
|
564
564
|
- all
|
|
565
565
|
- amax
|
|
566
566
|
- amin
|
|
567
|
+
- angle
|
|
567
568
|
- any
|
|
568
569
|
- arccos
|
|
569
570
|
- arccosh
|
|
570
|
-
- argmax
|
|
571
|
-
- angle
|
|
572
571
|
- arcsin
|
|
573
572
|
- arcsinh
|
|
574
573
|
- arctan
|
|
575
574
|
- arctanh
|
|
575
|
+
- argmax
|
|
576
576
|
- argmin
|
|
577
577
|
- argsort
|
|
578
578
|
- asin
|
|
@@ -582,19 +582,23 @@ tensor:
|
|
|
582
582
|
- atanh
|
|
583
583
|
- baddbmm
|
|
584
584
|
- bernoulli
|
|
585
|
+
- bfloat16
|
|
585
586
|
- bincount
|
|
586
587
|
- bitwise_and
|
|
587
588
|
- bitwise_or
|
|
588
589
|
- bitwise_xor
|
|
589
590
|
- bmm
|
|
590
591
|
- bool
|
|
592
|
+
- bool astype
|
|
591
593
|
- broadcast_to
|
|
594
|
+
- byte
|
|
592
595
|
- ceil
|
|
593
|
-
- cholesky_solve
|
|
594
596
|
- cholesky
|
|
597
|
+
- cholesky_solve
|
|
595
598
|
- clamp
|
|
596
599
|
- clip
|
|
597
600
|
- conj
|
|
601
|
+
- copy
|
|
598
602
|
- copysign
|
|
599
603
|
- cos
|
|
600
604
|
- cosh
|
|
@@ -606,11 +610,13 @@ tensor:
|
|
|
606
610
|
- deg2rad
|
|
607
611
|
- diag
|
|
608
612
|
- diagflat
|
|
613
|
+
- diagonal
|
|
609
614
|
- diff
|
|
610
615
|
- digamma
|
|
611
616
|
- div
|
|
612
617
|
- div_
|
|
613
618
|
- divide
|
|
619
|
+
- double
|
|
614
620
|
- equal
|
|
615
621
|
- erf
|
|
616
622
|
- erfc
|
|
@@ -618,13 +624,16 @@ tensor:
|
|
|
618
624
|
- exp
|
|
619
625
|
- expand_as
|
|
620
626
|
- expm1
|
|
627
|
+
- flatten
|
|
621
628
|
- flip
|
|
622
629
|
- fliplr
|
|
623
630
|
- flipud
|
|
631
|
+
- float
|
|
624
632
|
- float_power
|
|
625
633
|
- floor
|
|
626
634
|
- fmod
|
|
627
635
|
- frac
|
|
636
|
+
- from_numpy
|
|
628
637
|
- gather_elements
|
|
629
638
|
- ge
|
|
630
639
|
- geqrf
|
|
@@ -648,12 +657,12 @@ tensor:
|
|
|
648
657
|
- inner
|
|
649
658
|
- int
|
|
650
659
|
- inverse
|
|
660
|
+
- is_complex
|
|
661
|
+
- is_signed
|
|
651
662
|
- isclose
|
|
652
663
|
- isfinite
|
|
653
664
|
- isinf
|
|
654
665
|
- isnan
|
|
655
|
-
- is_complex
|
|
656
|
-
- is_signed
|
|
657
666
|
- isneginf
|
|
658
667
|
- isposinf
|
|
659
668
|
- isreal
|
|
@@ -704,28 +713,27 @@ tensor:
|
|
|
704
713
|
- new_ones
|
|
705
714
|
- new_zeros
|
|
706
715
|
- nextafter
|
|
707
|
-
- norm
|
|
708
716
|
- nonzero
|
|
717
|
+
- norm
|
|
709
718
|
- not_equal
|
|
710
719
|
- ormqr
|
|
711
720
|
- permute
|
|
712
721
|
- pow
|
|
713
722
|
- prod
|
|
714
723
|
- qr
|
|
724
|
+
- rad2deg
|
|
715
725
|
- ravel
|
|
716
726
|
- real
|
|
717
727
|
- reciprocal
|
|
718
728
|
- remainder
|
|
719
729
|
- renorm
|
|
720
|
-
- rad2deg
|
|
721
|
-
- tile
|
|
722
730
|
- repeat_interleave
|
|
723
731
|
- reshape
|
|
724
732
|
- reshape
|
|
725
|
-
-
|
|
733
|
+
- resize
|
|
726
734
|
- rot90
|
|
735
|
+
- round
|
|
727
736
|
- rsqrt
|
|
728
|
-
- sum_to_size
|
|
729
737
|
- scatter
|
|
730
738
|
- sgn
|
|
731
739
|
- short
|
|
@@ -745,7 +753,8 @@ tensor:
|
|
|
745
753
|
- sub
|
|
746
754
|
- sub_
|
|
747
755
|
- subtract
|
|
748
|
-
-
|
|
756
|
+
- sum
|
|
757
|
+
- sum_to_size
|
|
749
758
|
- svd
|
|
750
759
|
- swapaxes
|
|
751
760
|
- swapdims
|
|
@@ -753,13 +762,13 @@ tensor:
|
|
|
753
762
|
- take
|
|
754
763
|
- tan
|
|
755
764
|
- tanh
|
|
756
|
-
-
|
|
757
|
-
- swapaxes
|
|
765
|
+
- tensor_split
|
|
758
766
|
- tile
|
|
767
|
+
- to
|
|
759
768
|
- topk
|
|
760
|
-
-
|
|
761
|
-
- tensor_split
|
|
769
|
+
- trace
|
|
762
770
|
- transpose
|
|
771
|
+
- tril
|
|
763
772
|
- true_divide
|
|
764
773
|
- trunc
|
|
765
774
|
- unbind
|
|
@@ -769,17 +778,6 @@ tensor:
|
|
|
769
778
|
- view
|
|
770
779
|
- where
|
|
771
780
|
- xlogy
|
|
772
|
-
- from_numpy
|
|
773
|
-
- std
|
|
774
|
-
- take
|
|
775
|
-
- var
|
|
776
|
-
- all
|
|
777
|
-
- any
|
|
778
|
-
- copy
|
|
779
|
-
- diagonal
|
|
780
|
-
- flatten
|
|
781
|
-
- resize
|
|
782
|
-
- sum
|
|
783
781
|
|
|
784
782
|
mint.ops:
|
|
785
783
|
- abs
|
|
@@ -1027,3 +1025,21 @@ communication.comm_func:
|
|
|
1027
1025
|
- recv
|
|
1028
1026
|
- isend
|
|
1029
1027
|
- irecv
|
|
1028
|
+
|
|
1029
|
+
mint.distributed:
|
|
1030
|
+
- send
|
|
1031
|
+
- recv
|
|
1032
|
+
- broadcast
|
|
1033
|
+
- all_reduce
|
|
1034
|
+
- reduce
|
|
1035
|
+
- all_gather
|
|
1036
|
+
- gather
|
|
1037
|
+
- isend
|
|
1038
|
+
- irecv
|
|
1039
|
+
- scatter
|
|
1040
|
+
- reduce_scatter
|
|
1041
|
+
- all_to_all_single
|
|
1042
|
+
- all_to_all
|
|
1043
|
+
- all_gather_into_tensor
|
|
1044
|
+
- reduce_scatter_tensor
|
|
1045
|
+
- batch_isend_irecv
|
|
@@ -13,9 +13,12 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import os
|
|
17
16
|
from collections import defaultdict
|
|
17
|
+
import os
|
|
18
|
+
import types
|
|
18
19
|
|
|
20
|
+
import mindspore
|
|
21
|
+
from mindspore import nn
|
|
19
22
|
from mindspore._c_expression import PyNativeExecutor_
|
|
20
23
|
try:
|
|
21
24
|
from mindspore.common.api import _MindsporeFunctionExecutor
|
|
@@ -24,30 +27,31 @@ except ImportError:
|
|
|
24
27
|
|
|
25
28
|
from msprobe.core.common.log import logger
|
|
26
29
|
from msprobe.core.common.const import Const
|
|
30
|
+
from msprobe.core.common.runtime import Runtime
|
|
27
31
|
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
|
|
28
|
-
from msprobe.mindspore.
|
|
32
|
+
from msprobe.mindspore.common.const import Const as MsConst
|
|
33
|
+
from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
_api_register = get_api_register()
|
|
29
37
|
|
|
30
38
|
|
|
31
39
|
def dump_jit(name, in_feat, out_feat, is_forward):
|
|
32
40
|
pid = os.getpid()
|
|
33
|
-
|
|
34
|
-
index = ori_args.find("<")
|
|
35
|
-
if index != 0 and index != -1:
|
|
36
|
-
result = ori_args[0:index]
|
|
37
|
-
elif name is not None and "<" not in str(name):
|
|
38
|
-
result = str(name)
|
|
39
|
-
else:
|
|
40
|
-
result = "JitFunction"
|
|
41
|
+
name = name if name else "JitFunction"
|
|
41
42
|
if JitDump.need_dump():
|
|
42
43
|
if is_forward:
|
|
43
|
-
JitDump.jit_count
|
|
44
|
-
|
|
45
|
-
|
|
44
|
+
if name in JitDump.jit_count:
|
|
45
|
+
JitDump.jit_count[name] += 1
|
|
46
|
+
else:
|
|
47
|
+
JitDump.jit_count[name] = 0
|
|
48
|
+
name_template = (Const.JIT + Const.SEP + name + Const.SEP +
|
|
49
|
+
str(JitDump.jit_count[name]) + Const.SEP + Const.FORWARD)
|
|
46
50
|
JitDump.data_collector.update_api_or_module_name(name_template)
|
|
47
51
|
module_input_output = ModuleForwardInputsOutputs(args=in_feat, kwargs={}, output=out_feat)
|
|
48
52
|
JitDump.data_collector.forward_data_collect(name_template, None, pid, module_input_output)
|
|
49
53
|
else:
|
|
50
|
-
name_template = Const.JIT + Const.SEP +
|
|
54
|
+
name_template = Const.JIT + Const.SEP + name + Const.SEP + str(JitDump.jit_count[name]) + Const.SEP + \
|
|
51
55
|
Const.BACKWARD
|
|
52
56
|
JitDump.data_collector.update_api_or_module_name(name_template)
|
|
53
57
|
module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat, grad_output=out_feat)
|
|
@@ -57,7 +61,7 @@ def dump_jit(name, in_feat, out_feat, is_forward):
|
|
|
57
61
|
class JitDump(_MindsporeFunctionExecutor):
|
|
58
62
|
dump_config = None
|
|
59
63
|
jit_enable = False
|
|
60
|
-
jit_dump_switch =
|
|
64
|
+
jit_dump_switch = False
|
|
61
65
|
jit_count = defaultdict(int)
|
|
62
66
|
|
|
63
67
|
def __init__(self, *args, **kwargs):
|
|
@@ -68,19 +72,17 @@ class JitDump(_MindsporeFunctionExecutor):
|
|
|
68
72
|
self._executor = PyNativeExecutor_.get_instance()
|
|
69
73
|
|
|
70
74
|
def __call__(self, *args, **kwargs):
|
|
71
|
-
|
|
72
|
-
api_register.api_set_ori_func()
|
|
75
|
+
_api_register.restore_all_api()
|
|
73
76
|
out = super().__call__(*args, **kwargs)
|
|
74
|
-
if JitDump.jit_dump_switch and len(args) > 0:
|
|
75
|
-
if self.name
|
|
77
|
+
if JitDump.jit_dump_switch and len(args) > 0 and self.name:
|
|
78
|
+
if self.name != "construct":
|
|
76
79
|
dump_jit(self.name, args, out, True)
|
|
77
|
-
|
|
78
|
-
dump_jit(args[0], args, out, True)
|
|
80
|
+
elif Runtime.run_mode != MsConst.PYNATIVE_GRAPH_MODE and isinstance(args[0], nn.Cell):
|
|
81
|
+
dump_jit(args[0].__class__.__name__, args, out, True)
|
|
79
82
|
JitDump.jit_enable = True
|
|
80
83
|
elif len(args) == 0:
|
|
81
84
|
logger.warning(f"The jit function {self.name} has no input arguments, nothing will be dumped.")
|
|
82
|
-
|
|
83
|
-
api_register.api_set_hook_func()
|
|
85
|
+
_api_register.register_all_api()
|
|
84
86
|
return out
|
|
85
87
|
|
|
86
88
|
@classmethod
|
|
@@ -101,9 +103,15 @@ class JitDump(_MindsporeFunctionExecutor):
|
|
|
101
103
|
|
|
102
104
|
def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
|
|
103
105
|
if JitDump.jit_dump_switch and JitDump.jit_enable:
|
|
104
|
-
|
|
105
|
-
|
|
106
|
+
_api_register.restore_all_api()
|
|
107
|
+
if mindspore.__version__ >= "2.5":
|
|
108
|
+
output = self._executor.grad(grad, obj, weights, grad_position, False, *args, *(kwargs.values()))
|
|
109
|
+
else:
|
|
110
|
+
output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
|
|
106
111
|
if JitDump.jit_dump_switch and JitDump.jit_enable:
|
|
107
|
-
|
|
108
|
-
|
|
112
|
+
if isinstance(obj, types.FunctionType):
|
|
113
|
+
dump_jit(obj.__name__, args, None, False)
|
|
114
|
+
elif Runtime.run_mode != MsConst.PYNATIVE_GRAPH_MODE and isinstance(obj, nn.Cell):
|
|
115
|
+
dump_jit(obj.__class__.__name__, args, None, False)
|
|
116
|
+
_api_register.register_all_api()
|
|
109
117
|
return output
|
|
@@ -39,9 +39,12 @@ class KernelKbykDump:
|
|
|
39
39
|
common_set["input_output"] = 0
|
|
40
40
|
common_set["kernels"] = []
|
|
41
41
|
common_set["support_device"] = [0, 1, 2, 3, 4, 5, 6, 7]
|
|
42
|
-
e2e_set =
|
|
43
|
-
|
|
44
|
-
|
|
42
|
+
e2e_set = {
|
|
43
|
+
"enable": not config.async_dump,
|
|
44
|
+
"trans_flag": True,
|
|
45
|
+
"stat_calc_mode": config.stat_cal_mode,
|
|
46
|
+
"device_stat_precision_mode": config.device_stat_precision_mode,
|
|
47
|
+
}
|
|
45
48
|
|
|
46
49
|
if config.list:
|
|
47
50
|
common_set["dump_mode"] = 1
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright 2024 Huawei Technologies Co., Ltd
|
|
3
|
+
*
|
|
4
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
* you may not use this file except in compliance with the License.
|
|
6
|
+
* You may obtain a copy of the License at
|
|
7
|
+
*
|
|
8
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
*
|
|
10
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
* See the License for the specific language governing permissions and
|
|
14
|
+
* limitations under the License.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
#include "hook_dynamic_loader.h"
|
|
18
|
+
#include <sys/stat.h>
|
|
19
|
+
#include <cstdlib>
|
|
20
|
+
#include <cstring>
|
|
21
|
+
#include <pybind11/embed.h>
|
|
22
|
+
#include "utils/log_adapter.h"
|
|
23
|
+
|
|
24
|
+
namespace py = pybind11;
|
|
25
|
+
|
|
26
|
+
HookDynamicLoader &HookDynamicLoader::GetInstance()
|
|
27
|
+
{
|
|
28
|
+
static HookDynamicLoader instance;
|
|
29
|
+
return instance;
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
bool HookDynamicLoader::LoadFunction(void *handle, const std::string &functionName) {
|
|
33
|
+
void *func = dlsym(handle, functionName.c_str());
|
|
34
|
+
if (!func) {
|
|
35
|
+
MS_LOG(WARNING) << "Could not load function: " << functionName << ", error: " << dlerror();
|
|
36
|
+
return false;
|
|
37
|
+
}
|
|
38
|
+
funcMap_[functionName] = func;
|
|
39
|
+
return true;
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
bool HookDynamicLoader::LoadLibrary()
|
|
43
|
+
{
|
|
44
|
+
std::string msprobePath = "";
|
|
45
|
+
// 获取gil锁
|
|
46
|
+
py::gil_scoped_acquire acquire;
|
|
47
|
+
try {
|
|
48
|
+
py::module msprobeMod = py::module::import("msprobe.lib._msprobe_c");
|
|
49
|
+
if (!py::hasattr(msprobeMod, "__file__")) {
|
|
50
|
+
MS_LOG(WARNING) << "Adump mod not found";
|
|
51
|
+
return false;
|
|
52
|
+
}
|
|
53
|
+
msprobePath = msprobeMod.attr("__file__").cast<std::string>();
|
|
54
|
+
} catch (const std::exception& e) {
|
|
55
|
+
MS_LOG(WARNING) << "Adump mod path unable to get: " << e.what();
|
|
56
|
+
return false;
|
|
57
|
+
}
|
|
58
|
+
std::lock_guard<std::mutex> lock(mutex_);
|
|
59
|
+
if (handle_) {
|
|
60
|
+
MS_LOG(WARNING) << "Hook library already loaded!";
|
|
61
|
+
return false;
|
|
62
|
+
}
|
|
63
|
+
if (msprobePath == "") {
|
|
64
|
+
MS_LOG(WARNING) << "Adump path not loaded";
|
|
65
|
+
return false;
|
|
66
|
+
}
|
|
67
|
+
handle_ = dlopen(msprobePath.c_str(), RTLD_LAZY | RTLD_LOCAL);
|
|
68
|
+
if (!handle_) {
|
|
69
|
+
MS_LOG(WARNING) << "Failed to load Hook library: " << dlerror();
|
|
70
|
+
return false;
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
for (const auto &functionName : functionList_) {
|
|
74
|
+
if (!LoadFunction(handle_, functionName)) {
|
|
75
|
+
MS_LOG(WARNING) << "Failed to load adump function";
|
|
76
|
+
dlclose(handle_);
|
|
77
|
+
handle_ = nullptr;
|
|
78
|
+
return false;
|
|
79
|
+
}
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
MS_LOG(INFO) << "Hook library loaded successfully.";
|
|
83
|
+
return true;
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
bool HookDynamicLoader::UnloadLibrary()
|
|
87
|
+
{
|
|
88
|
+
std::lock_guard<std::mutex> lock(mutex_);
|
|
89
|
+
if (!handle_) {
|
|
90
|
+
MS_LOG(WARNING) << "Hook library hasn't been loaded.";
|
|
91
|
+
return false;
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
dlclose(handle_);
|
|
95
|
+
handle_ = nullptr;
|
|
96
|
+
funcMap_.clear();
|
|
97
|
+
MS_LOG(INFO) << "Library unloaded successfully.";
|
|
98
|
+
return true;
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
void *HookDynamicLoader::GetHooker(const std::string &funcName)
|
|
102
|
+
{
|
|
103
|
+
std::lock_guard<std::mutex> lock(mutex_);
|
|
104
|
+
auto iter = funcMap_.find(funcName);
|
|
105
|
+
if (iter == funcMap_.end()) {
|
|
106
|
+
MS_LOG(WARNING) << "Function not found: " << funcName;
|
|
107
|
+
return nullptr;
|
|
108
|
+
}
|
|
109
|
+
return iter->second;
|
|
110
|
+
}
|
|
@@ -27,27 +27,26 @@ constexpr auto kHookBegin = "MS_DbgOnStepBegin";
|
|
|
27
27
|
constexpr auto kHookEnd = "MS_DbgOnStepEnd";
|
|
28
28
|
|
|
29
29
|
class HookDynamicLoader {
|
|
30
|
-
|
|
31
|
-
|
|
30
|
+
public:
|
|
31
|
+
static HookDynamicLoader &GetInstance();
|
|
32
32
|
|
|
33
|
-
|
|
34
|
-
|
|
33
|
+
HookDynamicLoader(const HookDynamicLoader &) = delete;
|
|
34
|
+
HookDynamicLoader &operator=(const HookDynamicLoader &) = delete;
|
|
35
35
|
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
36
|
+
bool LoadLibrary();
|
|
37
|
+
bool UnloadLibrary();
|
|
38
|
+
void *GetHooker(const std::string &funcName);
|
|
39
39
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
bool validateLibraryPath(const std::string &libPath);
|
|
40
|
+
private:
|
|
41
|
+
// Helper functions
|
|
42
|
+
bool LoadFunction(void *handle, const std::string &functionName);
|
|
44
43
|
|
|
45
|
-
|
|
44
|
+
HookDynamicLoader() = default;
|
|
46
45
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
46
|
+
void *handle_ = nullptr;
|
|
47
|
+
std::vector<std::string> functionList_ = {kHookBegin, kHookEnd};
|
|
48
|
+
std::map<std::string, void *> funcMap_;
|
|
49
|
+
std::mutex mutex_;
|
|
51
50
|
};
|
|
52
51
|
|
|
53
52
|
#endif // HOOK_DYNAMIC_LOADER_H
|
|
@@ -19,22 +19,27 @@ import os
|
|
|
19
19
|
import traceback
|
|
20
20
|
|
|
21
21
|
import mindspore as ms
|
|
22
|
+
|
|
22
23
|
from msprobe.core.common.const import Const
|
|
23
24
|
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
24
25
|
from msprobe.core.common.file_utils import check_path_length, load_yaml
|
|
26
|
+
from msprobe.core.common.runtime import Runtime
|
|
27
|
+
from msprobe.core.hook_manager import HookSet
|
|
25
28
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
26
29
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
27
30
|
from msprobe.mindspore.common.log import logger
|
|
28
31
|
from msprobe.mindspore.common.utils import get_rank_if_initialized
|
|
29
32
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
30
|
-
from msprobe.mindspore.dump.hook_cell.
|
|
33
|
+
from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
|
|
31
34
|
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
32
35
|
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
33
36
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
34
37
|
from msprobe.mindspore.free_benchmark.common.utils import Tools
|
|
35
38
|
from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory
|
|
36
39
|
from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory
|
|
37
|
-
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Module-level API register shared by this module: ApiPyNativeSelfCheck.handle()
# passes build_hook to it and registers all APIs, while check_self() calls
# restore_all_api() before perturbing and register_all_api() again afterwards.
_api_register = get_api_register()
|
|
38
43
|
|
|
39
44
|
|
|
40
45
|
class ApiPyNativeSelfCheck:
|
|
@@ -60,8 +65,8 @@ class ApiPyNativeSelfCheck:
|
|
|
60
65
|
self.store_original_func()
|
|
61
66
|
|
|
62
67
|
def handle(self):
    """Hook every API through the module-level API register.

    ``self.build_hook`` is supplied to the register as the hook factory, and
    then all APIs are registered with it.
    """
    hook_factory = self.build_hook
    _api_register.initialize_hook(hook_factory)
    _api_register.register_all_api()
|
|
65
70
|
|
|
66
71
|
def build_hook(self, api_name):
|
|
67
72
|
def pre_hook(cell, input_data):
|
|
@@ -71,7 +76,7 @@ class ApiPyNativeSelfCheck:
|
|
|
71
76
|
ret = None
|
|
72
77
|
|
|
73
78
|
if not need_wrapper_func():
|
|
74
|
-
del cell.
|
|
79
|
+
del cell.msprobe_input_kwargs
|
|
75
80
|
return ret
|
|
76
81
|
|
|
77
82
|
api_name_with_id = api_name_with_id[:-1]
|
|
@@ -80,9 +85,9 @@ class ApiPyNativeSelfCheck:
|
|
|
80
85
|
api_name_with_id[api_name_with_id.find(Const.SEP) + 1:api_name_with_id.rfind(Const.SEP)])
|
|
81
86
|
if api_name in self.api_list:
|
|
82
87
|
ret = check_self(api_name_with_id, output_data, self.ori_func.get(api_name),
|
|
83
|
-
*input_data, **cell.
|
|
88
|
+
*input_data, **cell.msprobe_input_kwargs)
|
|
84
89
|
|
|
85
|
-
del cell.
|
|
90
|
+
del cell.msprobe_input_kwargs
|
|
86
91
|
return ret
|
|
87
92
|
|
|
88
93
|
def backward_hook(cell, grad_input, grad_output):
|
|
@@ -101,8 +106,13 @@ class ApiPyNativeSelfCheck:
|
|
|
101
106
|
|
|
102
107
|
def pre_backward_hook(cell, grad_input):
|
|
103
108
|
return None
|
|
104
|
-
|
|
105
|
-
return
|
|
109
|
+
|
|
110
|
+
return HookSet(
|
|
111
|
+
forward_hook=wrap_forward_hook,
|
|
112
|
+
forward_pre_hook=pre_hook,
|
|
113
|
+
backward_hook=wrap_backward_hook,
|
|
114
|
+
backward_pre_hook=pre_backward_hook
|
|
115
|
+
)
|
|
106
116
|
|
|
107
117
|
def store_original_func(self):
|
|
108
118
|
for api_name in self.api_list:
|
|
@@ -166,13 +176,13 @@ def check_self(api_name_with_id, output, ori_func, *args, **kwargs):
|
|
|
166
176
|
return ret
|
|
167
177
|
|
|
168
178
|
logger.info(f"[{api_name_with_id}] is {Config.handler_type}ing.")
|
|
169
|
-
|
|
179
|
+
_api_register.restore_all_api()
|
|
170
180
|
|
|
171
181
|
try:
|
|
172
182
|
perturbation = PerturbationFactory.create(api_name_with_id)
|
|
173
183
|
params.fuzzed_result = perturbation.handle(params)
|
|
174
184
|
if params.fuzzed_result is False:
|
|
175
|
-
|
|
185
|
+
_api_register.register_all_api()
|
|
176
186
|
return ret
|
|
177
187
|
if Config.stage == Const.BACKWARD:
|
|
178
188
|
params.original_result = Tools.get_grad(params.original_func, *params.args, **params.kwargs)
|
|
@@ -183,7 +193,7 @@ def check_self(api_name_with_id, output, ori_func, *args, **kwargs):
|
|
|
183
193
|
logger.error(f"[{api_name_with_id}] Error: {str(e)}")
|
|
184
194
|
logger.error(f"[{api_name_with_id}] Error detail: {traceback.format_exc()}")
|
|
185
195
|
|
|
186
|
-
|
|
196
|
+
_api_register.register_all_api()
|
|
187
197
|
return ret
|
|
188
198
|
|
|
189
199
|
|
|
@@ -19,10 +19,10 @@ from typing import Any, Optional
|
|
|
19
19
|
import mindspore as ms
|
|
20
20
|
from mindspore import Tensor, ops
|
|
21
21
|
|
|
22
|
+
from msprobe.core.common.runtime import Runtime
|
|
22
23
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
23
24
|
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
24
25
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
25
|
-
from msprobe.mindspore.runtime import Runtime
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class Tools:
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
17
|
+
from msprobe.mindspore.common.log import logger
|
|
17
18
|
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
18
19
|
from msprobe.mindspore.free_benchmark.perturbation.add_noise import AddNoisePerturbation
|
|
19
20
|
from msprobe.mindspore.free_benchmark.perturbation.bit_noise import BitNoisePerturbation
|
|
@@ -41,4 +42,5 @@ class PerturbationFactory:
|
|
|
41
42
|
if perturbation:
|
|
42
43
|
return perturbation(api_name_with_id)
|
|
43
44
|
else:
|
|
44
|
-
|
|
45
|
+
logger.error(f'{Config.pert_type} is a invalid perturbation type')
|
|
46
|
+
raise ValueError
|