mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -1,146 +0,0 @@
|
|
|
1
|
-
import re
|
|
2
|
-
|
|
3
|
-
from msprobe.core.common.const import Const
|
|
4
|
-
from msprobe.core.common.log import logger
|
|
5
|
-
from msprobe.core.common.utils import CompareException
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class Trie:
    """Prefix tree built from the ``Const.SEP``-separated parts of dumped names.

    Every node keeps the type name it represents, the call counts recorded
    for it, the kind of entry it was inserted as, and its child nodes keyed
    by name part.
    """

    def __init__(self, type_name=None, has_data=False):
        self.children = {}
        self.call_count_list = []
        self.node_type = None
        self.type_name = type_name
        self.has_data = has_data

    def __repr__(self):
        call_number = len(self.call_count_list)
        return f"Node(type_name={self.type_name}, has_data={self.has_data}, call number={call_number})"

    def insert(self, word, word_type="func"):
        """Record ``word`` in the tree.

        The name is read as ``xxx, node_name, type_name, execute_num``, e.g.
        ``Cell.network_with_loss.language_model.encoder.layers.1.attention.out_proj.RowParallelLinear.1``:

        * prefix_name_list: ``Cell.network_with_loss.language_model.encoder.layers.1.attention``
        * node_name: ``out_proj``
        * type_name: ``RowParallelLinear``
        * call_count: ``1``

        Raises:
            CompareException: ``INDEX_OUT_OF_BOUNDS_ERROR`` when the name has
                fewer than two parts.
        """
        parts = word.split(Const.SEP)
        if len(parts) < 2:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)

        # The last two parts are the type name and the call count; everything
        # before them is the path walked through the tree.
        *prefix_name_list, type_name, call_count = parts

        current = self
        for part in prefix_name_list:
            child = current.children.get(part)
            if child is None:
                child = current.children[part] = Trie()
            current = child
            # Intermediate nodes without a known type fall back to their own name.
            if current.type_name is None:
                current.type_name = part

        current.type_name = type_name
        current.has_data = True
        current.call_count_list.append(call_count)
        current.node_type = word_type
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
class DFSConverter:
    """Depth-first walker that pairs each recorded tree name with its mapped name.

    ``mapping`` is looked up by a node's type name and yields a dict that
    renames that node's children; children without an entry keep their name.
    """

    def __init__(self, mapping, max_depth=100):
        self.mapping = mapping
        self.max_depth = max_depth
        self.result = {}

    def traverse_and_collect(self, node, path="", mapping_path="", depth=0):
        """Walk ``node`` and its descendants, filling and returning ``self.result``.

        Args:
            node: tree node exposing ``type_name``, ``has_data``,
                ``call_count_list``, ``node_type`` and ``children``.
            path: name accumulated so far using the original child names.
            mapping_path: name accumulated so far using the converted child names.
            depth: current recursion depth.

        Returns:
            dict mapping every original name to its converted name.

        Raises:
            CompareException: ``RECURSION_LIMIT_ERROR`` when ``depth`` exceeds
                ``self.max_depth``.
        """
        if depth > self.max_depth:
            logger.error("The converted data depth is too large, please check the data")
            raise CompareException(CompareException.RECURSION_LIMIT_ERROR)

        if node is None:
            return self.result

        def _extend(prefix, part):
            # Avoid a leading separator when the accumulated prefix is still empty.
            return f"{prefix}.{part}" if prefix else part

        type_name = node.type_name
        if node.has_data:
            for count in node.call_count_list:
                # "Cell" entries omit the type name from the generated key.
                if node.node_type == "Cell":
                    suffix = f".{count}"
                else:
                    suffix = f".{type_name}.{count}"
                self.result[f"{path}{suffix}"] = f"{mapping_path}{suffix}"

        child_renames = self.mapping.get(type_name, {})

        for child_name, child_node in node.children.items():
            self.traverse_and_collect(
                child_node,
                _extend(path, child_name),
                _extend(mapping_path, child_renames.get(child_name, child_name)),
                depth + 1,
            )

        return self.result
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
def get_mapping_list(ms_tree, mapping):
    """Return ``(ms_name, pt_name)`` pairs for every name recorded in ``ms_tree``.

    The tree is converted with ``DFSConverter`` using ``mapping``; a leading
    ``Cell`` in each converted name is then rewritten to ``Module``.
    """
    converter = DFSConverter(mapping)
    name_pairs = converter.traverse_and_collect(ms_tree)
    return [
        (ms_name, re.sub(r"^Cell", "Module", pt_name))
        for ms_name, pt_name in name_pairs.items()
    ]
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
def get_prefix_mapping(scope_list):
    """Map "layer name" to "layer name.class_name".

    Only entries whose ``origin_data`` starts with ``Cell`` or ``Module`` are
    used. The key is the full name with its second-to-last part removed, e.g.
    ``Cell.network_with_loss.language_model.embedding.3`` maps to
    ``Cell.network_with_loss.language_model.embedding.Embedding.3``.

    Raises:
        CompareException: ``INDEX_OUT_OF_BOUNDS_ERROR`` when a name has fewer
            than two parts.
    """
    prefix_to_full = {}
    for full_name, info in scope_list.items():
        if not info.get("origin_data").startswith(("Cell", "Module")):
            continue
        pieces = full_name.split(Const.SEP)
        if len(pieces) < 2:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
        # Drop the class-name part (second to last) and keep the trailing part.
        *head, _class_name, tail = pieces
        prefix_to_full[Const.SEP.join([*head, tail])] = full_name
    return prefix_to_full
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
def get_layer_mapping(ms_scope_list, pt_scope_list, mapping):
    """Build the list of ``(ms_name, pt_name)`` pairs between two scope lists.

    Args:
        ms_scope_list: dict of name -> info dict (with ``origin_data``) for the
            names that are inserted into the tree.
        pt_scope_list: dict of the same shape used to resolve the converted names.
        mapping: per-type child renames handed to ``get_mapping_list``.

    Returns:
        list of ``(final_ms_name, final_pt_name)`` tuples. For prefix ("cell")
        matches the second element may be ``None`` when the converted name is
        unknown; "func" matches with an unknown converted name are skipped.
    """

    def _drop_last_part(name):
        # remove forward/backward
        return Const.SEP.join(name.split(Const.SEP)[:-1])

    # 1. get layer prefix to full name mapping
    # e.g.: Cell.network_with_loss.language_model.embedding.3 :
    #       Cell.network_with_loss.language_model.embedding.Embedding.3
    ms_prefix2fullname = get_prefix_mapping(ms_scope_list)

    # 2. build trie tree
    ms_tree = Trie(type_name="Cell")
    for scope_name, scope_info in ms_scope_list.items():
        data_type = scope_info.get('origin_data').split(Const.SEP)[0]
        ms_tree.insert(scope_name, data_type)
    msname2ptname = get_mapping_list(ms_tree, mapping)

    # 3. get pt layer prefix to full name mapping
    # e.g.: Module.network_with_loss.language_model.embedding.3 :
    #       Module.network_with_loss.language_model.embedding.Embedding.3
    pt_prefix2fullname = get_prefix_mapping(pt_scope_list)

    final_mapping = []
    for ms_name, pt_name in msname2ptname:
        if ms_name in ms_prefix2fullname:
            # cell
            pair = (ms_prefix2fullname.get(ms_name), pt_prefix2fullname.get(pt_name, None))
        elif ms_name in ms_scope_list:
            # func
            ms_origin = _drop_last_part(ms_scope_list.get(ms_name)['origin_data'])
            pt_entry = pt_scope_list.get(pt_name, None)
            if not pt_entry:
                continue
            pair = (ms_origin, _drop_last_part(pt_entry['origin_data']))
        else:
            pair = (ms_name, pt_name)
        final_mapping.append(pair)

    return final_mapping
|
|
@@ -1,107 +0,0 @@
|
|
|
1
|
-
from msprobe.core.common.const import Const
|
|
2
|
-
from msprobe.core.common.log import logger
|
|
3
|
-
|
|
4
|
-
def find_regard_scope(lines, start_sign, end_sign):
    """Locate the start and end positions of the scope of interest in ``lines``.

    Scanning stops at the first line that contains ``end_sign`` but not
    ``start_sign``. The start position is the last line containing
    ``start_sign`` seen before that point.

    Returns:
        ``(start_pos, end_pos)``; either value is ``-1`` when not found.
    """
    first, last = -1, -1
    for position, text in enumerate(lines):
        if start_sign in text:
            first = position
            continue
        if end_sign in text:
            last = position
            break
    return first, last
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def find_stack_func_list(lines):
    """Extract the function names from comma-separated stack lines, in reverse order.

    Lines whose file field contains any entry of ``Const.FILE_SKIP_LIST``, or
    whose function field contains any entry of ``Const.FUNC_SKIP_LIST``, are
    ignored.
    """
    collected = []
    # Filter and process the regard scope
    for stack_line in lines:
        fields = stack_line.split(',')

        file_field = fields[Const.STACK_FILE_INDEX]
        if any(marker in file_field for marker in Const.FILE_SKIP_LIST):
            continue

        func_field = fields[Const.STACK_FUNC_INDEX]
        if any(marker in func_field for marker in Const.FUNC_SKIP_LIST):
            continue

        collected.append(func_field.split()[Const.STACK_FUNC_ELE_INDEX])

    # Reverse the collected names to produce the final result
    collected.reverse()
    return collected
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def get_duplicated_name(components):
    """Return ``components`` with the part at ``Const.CONSTRUCT_NAME_INDEX`` repeated.

    When there are fewer than three parts, or the part at that index is all
    digits, a warning is logged and ``components`` is returned unchanged.
    """
    too_short = len(components) < 3
    if too_short or components[Const.CONSTRUCT_NAME_INDEX].isdigit():
        logger.warning("key in construct.json is shorter than 3 parts or not name valid.")
        return components

    name_idx = Const.CONSTRUCT_NAME_INDEX
    # Repeat the name, e.g. Functional.add.add.X ward
    return components[:name_idx + 1] + components[name_idx:]
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def modify_mapping_with_stack(stack, construct):
    """Rebuild the keys of ``construct`` into normalized scope names.

    Args:
        stack: dict of key -> list of stack lines for that key (may lack keys).
        construct: dict of key -> parent scope name (or a falsy value when the
            key has no parent).

    Returns:
        dict of rebuilt key -> ``{Const.ORIGIN_DATA, Const.SCOPE, Const.STACK}``
        describing the original key, its parent scope and the joined call-stack
        function names (``None`` when there are none). An empty dict is
        returned when either argument is empty.
    """
    if not stack or not construct:
        return {}

    # Whether this is a MindSpore data structure
    is_ms = any("Cell" in ii for ii in construct)
    # The adjusted mapping structure
    final_pres = {}
    # Work out which scope each key belongs to
    for key in construct:
        key_components = key.split(Const.SEP)
        code_list = stack.get(key, None)
        parent_node = construct.get(key, None)
        # Names that do not start with a standard prefix are converted to one
        if not key.startswith(("Module", "Cell")):
            # Without a parent scope name, default the top level to Module or Cell
            if not parent_node:
                # Turn the node name into the standard Module or Cell
                key_components[0] = "Cell" if is_ms else "Module"
                # Repeat the node's name as its type, e.g. add.add (add is at index -3)
                duplicated_components = get_duplicated_name(key_components)
                modified_key = Const.SEP.join(duplicated_components)

                modified_key = modified_key.replace(".forward", "").replace(".backward", "")
                final_pres[modified_key] = {Const.ORIGIN_DATA: key, Const.SCOPE: None, Const.STACK: None}
                continue
            parent = parent_node.split(Const.SEP)
            if len(parent) < 4:
                logger.info("Parent name in construct.json is not valid")
                continue
            parent_idx = Const.NAME_FIRST_POSSIBLE_INDEX if not \
                parent[Const.NAME_FIRST_POSSIBLE_INDEX].isdigit() else Const.NAME_SECOND_POSSIBLE_INDEX
            parent_name = parent[parent_idx]

            if code_list:
                # {name}.Class.count_number.X ward Or {name}.Class.count_number.X ward.ele_number
                if parent_name.endswith('s'):
                    parent_name = parent_name[:-1]
                if len(key_components) < 3:
                    logger.info("The length of key in construct is less than 3, please check")
                    continue
                # {name}.count_number.X ward
                func_name = key_components[-3]
                start_pos, end_pos = find_regard_scope(code_list, func_name, parent_name)

                # Take the stack lines in the located range
                # NOTE(review): find_regard_scope returns -1 for "not found", which
                # Python slicing treats as "last element" — confirm this is intended.
                regard_scope = code_list[start_pos:end_pos]

                func_stack_list = find_stack_func_list(regard_scope)
            else:
                func_stack_list = []
            # Composition: the parent's parts (up to and including its name), then the
            # call-stack names, then the original key with its name part repeated
            final_res_key = Const.SEP.join(parent[:parent_idx + 1] + func_stack_list +
                                           key_components[1:Const.CONSTRUCT_NAME_INDEX + 1] +
                                           key_components[Const.CONSTRUCT_NAME_INDEX:])
            # str.strip() takes a *set of characters*, not a substring, so it could eat
            # legitimate name characters at either end and never removes a marker in the
            # middle. Remove the markers as substrings, like the no-parent branch above.
            final_res_key = final_res_key.replace(".forward", "").replace(".backward", "")
        else:
            final_res_key = Const.SEP.join(key_components[:-2] + [key_components[-1]])
            func_stack_list = []
        final_pres[final_res_key] = {Const.ORIGIN_DATA: key, Const.SCOPE: parent_node,
                                     Const.STACK: Const.SEP.join(func_stack_list) if func_stack_list else None}
    return final_pres
|
|
@@ -1,57 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
from msprobe.mindspore.common.const import Const, FreeBenchmarkConst
|
|
17
|
-
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
18
|
-
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
19
|
-
from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory
|
|
20
|
-
from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
class ForwardSelfChecker:
    """Runs an API with and without perturbation and hands both results to a handler."""

    def __init__(self, api_name: str):
        self.api_name = api_name

    def handle(self, params: HandlerParams):
        """
        Actual execution logic of the decorator.

        The perturbed run is performed first, then the original function is
        called. When the perturbed result is exactly ``False`` the original
        result is returned as is; otherwise both results are post-processed.
        """
        fuzzer = PerturbationFactory.create(self.api_name)
        params.fuzzed_result = fuzzer.handle(params)
        params.original_result = params.original_func(*params.args, **params.kwargs)
        if params.fuzzed_result is False:
            return params.original_result
        return self.deal_fuzzed_and_original_result(params)

    def get_compare_data(self, params: HandlerParams):
        """Prepare the values to compare; only communication APIs are adjusted."""
        if self.api_name in Const.COMMUNICATION_API_LIST:
            # Handling for communication APIs
            params.fuzzed_result = params.fuzzed_value
            improve_precision = Config.pert_type == FreeBenchmarkConst.IMPROVE_PRECISION
            params.original_result = params.args if improve_precision else params.args[params.index]

    def deal_fuzzed_and_original_result(self, params: HandlerParams):
        """Run the result handler; communication APIs always return the original result."""
        # Saved up front because get_compare_data may overwrite params.original_result.
        saved_original = params.original_result
        self.get_compare_data(params)
        handled = HandlerFactory.create(self.api_name).handle(params)
        if self.api_name in Const.COMMUNICATION_API_LIST:
            return saved_original
        return handled
|
|
@@ -1,122 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
import os
|
|
17
|
-
import sys
|
|
18
|
-
import traceback
|
|
19
|
-
from functools import wraps
|
|
20
|
-
from typing import Dict, List, Tuple
|
|
21
|
-
|
|
22
|
-
from mindspore import ops
|
|
23
|
-
|
|
24
|
-
from msprobe.mindspore.common.log import logger
|
|
25
|
-
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
26
|
-
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
27
|
-
from msprobe.mindspore.free_benchmark.decorator.dec_forward import ForwardSelfChecker
|
|
28
|
-
from msprobe.mindspore.runtime import Runtime
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def decorate(original_func, decorate_func, api_name=None):
    """
    General decorator.

    Wraps ``original_func`` so that, when ``need_wrapper_func()`` allows it,
    the call is routed through ``decorate_func`` instead of being executed
    directly.

    Args:
        original_func: the callable being wrapped.
        decorate_func: callable that receives the params object built by
            ``data_pre_deal`` and returns the value handed back to the caller.
        api_name: name used in the log messages and passed to ``data_pre_deal``.

    Returns:
        The wrapping function. Any exception raised while checking is logged
        and the call falls back to ``original_func(*args, **kwargs)``.
    """
    @wraps(original_func)
    def fuzz_wrapper(*args, **kwargs):
        # stack_depth_check() recognises nested calls by this function's name,
        # so the wrapper must keep the name "fuzz_wrapper".
        try:
            if Runtime.rank_id == -1:
                raw_rank = os.environ.get("RANK_ID")
                if raw_rank is not None:
                    # Environment values are always strings. Convert to int so
                    # the rank has the same type as the -1 sentinel and can
                    # match integer entries in Config.ranks
                    # (presumably a list of ints - TODO confirm).
                    try:
                        Runtime.rank_id = int(raw_rank)
                    except ValueError:
                        # Non-numeric value: keep the raw string as before.
                        Runtime.rank_id = raw_rank
            if need_wrapper_func():
                logger.info(f"[{api_name}] is checking.")
                return decorate_func(data_pre_deal(api_name, original_func, *args, **kwargs))
        except Exception as e:
            # Best effort: a failing check must never break the user's call.
            logger.error(f"[{api_name}] Error: {str(e)}")
            logger.error(f"[{api_name}] Error detail: {traceback.format_exc()}")

        return original_func(*args, **kwargs)

    return fuzz_wrapper
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
def decorate_forward_function(func, api_name=None):
    """
    Forward decorator.

    Wraps ``func`` with ``decorate`` so that checked calls are handled by a
    ``ForwardSelfChecker``. When ``api_name`` is empty, ``func.__name__`` is
    used instead.
    """
    resolved_name = api_name if api_name else func.__name__

    def forward_func(params: HandlerParams):
        return ForwardSelfChecker(resolved_name).handle(params)

    return decorate(func, forward_func, resolved_name)
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
def stack_depth_check() -> bool:
    """
    Report whether the caller is running inside at most one ``fuzz_wrapper``.

    Walks the call stack starting at the caller's frame and counts frames whose
    function name is ``fuzz_wrapper``.

    Returns:
        ``False`` as soon as a second ``fuzz_wrapper`` frame is found,
        ``True`` otherwise.
    """
    wrappers_seen = 0
    current = sys._getframe(1)
    while current is not None:
        if current.f_code.co_name == "fuzz_wrapper":
            wrappers_seen += 1
            if wrappers_seen > 1:
                return False
        current = current.f_back
    return True
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
def get_target_arg_index(args: Tuple) -> int:
    """
    Type validation: find the first positional argument of a supported type.

    An argument qualifies when ``ops.is_tensor`` and ``ops.is_floating_point``
    both accept it, or when it is not a tensor and is a list, tuple or dict.
    Tensors that are not floating point are skipped.

    Returns:
        The index of the first qualifying argument, or ``-1`` when there is none.
    """
    for position, candidate in enumerate(args):
        if ops.is_tensor(candidate):
            if ops.is_floating_point(candidate):
                return position
        elif isinstance(candidate, (List, Tuple, Dict)):
            return position
    return -1
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
def data_pre_deal(api_name, func, *args, **kwargs):
    """
    Build the ``HandlerParams`` object describing one intercepted call.

    Stores the positional and keyword arguments, the original function and the
    index chosen by ``get_target_arg_index``.

    Raises:
        Exception: when no positional argument has a supported type.
    """
    handler_params = HandlerParams()
    handler_params.args = args
    handler_params.kwargs = kwargs
    handler_params.original_func = func
    target_index = get_target_arg_index(args)
    if target_index == -1:
        raise Exception(f"{api_name} has no supported input type")
    handler_params.index = target_index
    return handler_params
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
def need_wrapper_func():
    """
    Decide whether the current call should go through the self check.

    Returns ``False`` when the runtime is not running or the check is disabled,
    when the call is nested inside another ``fuzz_wrapper``, when ``Config.steps``
    is non-empty and does not contain the current step, or when ``Config.ranks``
    is non-empty, the rank is known (not ``-1``) and is not listed. Returns
    ``True`` otherwise.
    """
    if not Runtime.is_running or not Config.is_enable:
        return False
    # Called directly from this frame on purpose: stack_depth_check() starts
    # its walk at the caller's frame.
    if not stack_depth_check():
        return False
    if Config.steps and Runtime.step_count not in Config.steps:
        return False
    rank_filtered_out = Config.ranks and Runtime.rank_id != -1 and Runtime.rank_id not in Config.ranks
    return not rank_filtered_out
|
|
@@ -1,84 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
import torch
|
|
17
|
-
import torch.nn as nn
|
|
18
|
-
from msprobe.core.common.const import Const
|
|
19
|
-
from msprobe.core.common.exceptions import MsprobeException
|
|
20
|
-
from msprobe.core.data_dump.scope import BaseScope
|
|
21
|
-
from msprobe.pytorch.common.log import logger
|
|
22
|
-
from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger
|
|
23
|
-
from msprobe.pytorch.hook_module.api_registry import api_register
|
|
24
|
-
from msprobe.pytorch.service import torch_version_above_or_equal_2
|
|
25
|
-
|
|
26
|
-
# Handles returned by the hook registrations made in register_hook();
# remove_hook() calls remove() on them and module_dump_end() empties the list.
hook_handle_list = []
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def module_dump(module, dump_name):
    """
    Validate the arguments, then register the dump hooks on ``module``.

    Args:
        module: must be an ``nn.Module`` instance.
        dump_name: must be a ``str``; passed on to ``register_hook``.

    Raises:
        MsprobeException: with ``INVALID_PARAM_ERROR`` when either argument has
            the wrong type (an error is logged first).
    """
    type_requirements = (
        (module, nn.Module, "The parameter module in module_dump must be a Module subclass."),
        (dump_name, str, "The parameter dump_name in module_dump must be a str type."),
    )
    for value, expected_type, message in type_requirements:
        if not isinstance(value, expected_type):
            logger.error(message)
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

    api_register.api_originality()
    register_hook(module, dump_name)
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
def module_dump_end():
    """
    End a module dump: call ``api_register.api_modularity()``, remove every
    registered hook and empty ``hook_handle_list``.
    """
    api_register.api_modularity()
    remove_hook()
    del hook_handle_list[:]
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def register_hook(module, dump_name):
    """
    Register the dump hooks on ``module`` and record every returned handle in
    ``hook_handle_list``.

    Args:
        module: the module the hooks are attached to.
        dump_name: user supplied name that becomes part of the hook prefix.
    """
    # Prefix layout: <module type> SEP <dump_name> SEP <module class name> SEP
    prefix = BaseScope.Module_Type_Module + Const.SEP + dump_name + Const.SEP + module.__class__.__name__ + Const.SEP

    pdg = PrecisionDebugger()
    # build_hook returns four values; the first one is not used here.
    _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = \
        pdg.service.build_hook(BaseScope.Module_Type_Module, prefix)

    if torch_version_above_or_equal_2:
        # Forward hook that also receives the keyword arguments.
        forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True)
        hook_handle_list.append(forward_hook_handle)
    else:
        # In this branch the backward STOP node hook is registered before the
        # forward hook and before the backward data hook below.
        # NOTE(review): the registration order looks deliberate - confirm
        # before reordering anything in this function.
        pdg.service.check_register_full_backward_hook(module)
        full_backward_hook_handle = module.register_full_backward_hook(
            pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
        forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2)
        hook_handle_list.extend([full_backward_hook_handle, forward_hook_handle])
    # Registered in both branches: the backward hook built by build_hook.
    pdg.service.check_register_full_backward_hook(module)
    full_backward_hook_handle = module.register_full_backward_hook(backward_hook)

    # Node hooks marking START / STOP of the module's forward pass.
    forward_pre_hook_handle = module.register_forward_pre_hook(
        pdg.service.module_processor.node_hook(prefix + Const.FORWARD, Const.START))
    forward_hook_handle = module.register_forward_hook(
        pdg.service.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
    hook_handle_list.extend([full_backward_hook_handle, forward_pre_hook_handle, forward_hook_handle])

    if torch_version_above_or_equal_2:
        # Node hooks marking START / STOP of the module's backward pass; the
        # STOP hook for the other branch was already registered above.
        backward_pre_hook_handle = module.register_full_backward_pre_hook(
            pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
        pdg.service.check_register_full_backward_hook(module)
        full_backward_hook_handle = module.register_full_backward_hook(
            pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
        hook_handle_list.extend([backward_pre_hook_handle, full_backward_hook_handle])
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
def remove_hook():
    """
    Call ``remove()`` on every entry of ``hook_handle_list`` that is a
    ``torch.utils.hooks.RemovableHandle``; other entries are ignored.
    The list itself is left unchanged.
    """
    removable_handles = (
        handle for handle in hook_handle_list
        if isinstance(handle, torch.utils.hooks.RemovableHandle)
    )
    for handle in removable_handles:
        handle.remove()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|