mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/pytorch/compare/pt_compare.py

```diff
@@ -14,19 +14,29 @@
 # limitations under the License.
 
 import os.path
+
 import torch
+
 from msprobe.core.common.const import FileCheckConst
-from msprobe.pytorch.common.log import logger
 from msprobe.core.common.exceptions import FileCheckException
-from msprobe.core.compare.acc_compare import Comparator
-from msprobe.core.common.utils import check_configuration_param, check_compare_param, \
-    CompareException, set_dump_path, get_dump_mode
 from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml
+from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
+    set_dump_path
+from msprobe.core.compare.acc_compare import Comparator, ModeConfig
+from msprobe.core.compare.utils import set_stack_json_path
+from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import load_pt
 
 
-class PTComparator
-    def __init__(self, data_mapping=None):
+class PTComparator(Comparator):
+    def __init__(self, mode_config, data_mapping=None):
+        super().__init__(mode_config)
+
+        self.stack_mode = mode_config.stack_mode
+        self.auto_analyze = mode_config.auto_analyze
+        self.fuzzy_match = mode_config.fuzzy_match
+        self.dump_mode = mode_config.dump_mode
+
         self.frame_name = PTComparator.__name__
         self.data_mapping = data_mapping
         if isinstance(self.data_mapping, str) or self.data_mapping is None:
@@ -37,23 +47,24 @@ class PTComparator (Comparator):
             raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
                             f"{type(self.data_mapping)}")
 
-
+    @staticmethod
+    def load_mapping_file(mapping_file):
         if isinstance(mapping_file, str):
            mapping_dict = load_yaml(mapping_file)
        else:
            mapping_dict = {}
        return mapping_dict
-
+
    def read_npy_data(self, dir_path, file_name):
        if not file_name:
            return None
        data_path = os.path.join(dir_path, file_name)
        path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
-
+                                  FileCheckConst.PT_SUFFIX, False)
        data_path = path_checker.common_check()
        try:
-
-
+           # detach because numpy can not process gradient information
+           data_value = load_pt(data_path, to_cpu=True).detach()
        except RuntimeError as e:
            # catch the exception raised inside load_pt
            logger.error(f"Failed to load the .pt file at {data_path}.")
@@ -65,20 +76,29 @@ class PTComparator (Comparator):
         if data_value.dtype == torch.bfloat16:
             data_value = data_value.to(torch.float32)
         data_value = data_value.numpy()
-        return data_value
-
-
-def compare(input_param, output_path,
+        return data_value
+
+
+def compare(input_param, output_path, **kwargs):
     try:
+        auto_analyze = kwargs.get('auto_analyze', True)
+        fuzzy_match = kwargs.get('fuzzy_match', False)
+        data_mapping = kwargs.get('data_mapping', None)
+        suffix = kwargs.get('suffix', '')
+
        set_dump_path(input_param)
        dump_mode = get_dump_mode(input_param)
+       if "stack_json_path" in input_param:
+           stack_mode = kwargs.get('stack_mode', False)
+       else:
+           stack_mode = set_stack_json_path(input_param)  # set stack_mode and set "stack_json_path" in input_param
        check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
        create_directory(output_path)
-       check_compare_param(input_param, output_path, dump_mode)
-       data_mapping = kwargs.get('data_mapping', None)
+       check_compare_param(input_param, output_path, dump_mode, stack_mode)
    except (CompareException, FileCheckException) as error:
        logger.error('Compare failed. Please check the arguments and do it again!')
        raise CompareException(error.code) from error
-
-
-
+
+    mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode)
+    pt_comparator = PTComparator(mode_config, data_mapping)
+    pt_comparator.compare_core(input_param, output_path, suffix=suffix)
```
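Note: the comparison switches that used to be separate arguments (stack mode, auto analysis, fuzzy matching, dump mode) now travel through a single `ModeConfig`, and `compare()` takes them as keyword arguments. A minimal, hedged sketch of driving the new entry point; the dump-file paths and the `suffix` value are hypothetical:

```python
from msprobe.pytorch.compare.pt_compare import compare

# Hypothetical dump.json files from an NPU run and a benchmark run.
input_param = {
    "npu_json_path": "./npu_dump/dump.json",
    "bench_json_path": "./bench_dump/dump.json",
    "is_print_compare_log": True,
}

# Per the hunk above, the stack_mode kwarg is only read when
# "stack_json_path" is supplied explicitly; otherwise
# set_stack_json_path() derives it from the dump directory.
compare(input_param, output_path="./output",
        auto_analyze=True, fuzzy_match=False, suffix="_rank0")
```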
msprobe/pytorch/debugger/debugger_config.py

```diff
@@ -34,6 +34,7 @@ class DebuggerConfig:
         self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
         self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
         self.framework = Const.PT_FRAMEWORK
+        self.async_dump = common_config.async_dump if common_config.async_dump else False
 
         if self.level == Const.LEVEL_L2:
             self.is_backward_kernel_dump = False
@@ -74,29 +75,43 @@ class DebuggerConfig:
         if not self.dump_path:
             raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                    f"The dump_path not found.")
+        if not isinstance(self.async_dump, bool):
+            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+                                   f"The parameter async_dump should be bool.")
 
     def check(self):
         self.check_kwargs()
         return True
 
     def check_model(self, instance, start_model):
-        if self.level not in [
+        if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
             if instance.model is not None or start_model is not None:
-                logger.
+                logger.info_on_rank_0(
                     f"The current level is not L0 or mix level, so the model parameters will not be used.")
             return
-        if start_model is None:
-
-
-
-
-
-        if isinstance(
-
+        if start_model is None and instance.model is None:
+            logger.error_on_rank_0(
+                f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' parameter.")
+            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
+
+        instance.model = start_model if start_model is not None else instance.model
+        if isinstance(instance.model, torch.nn.Module):
+            return
+
+        error_model = None
+        if isinstance(instance.model, (list, tuple)):
+            for model in instance.model:
+                if not isinstance(model, torch.nn.Module):
+                    error_model = model
+                    break
         else:
-
+            error_model = instance.model
+
+        if error_model is not None:
+            error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
+                          f"type, currently there is a {type(error_model)} type.")
             raise MsprobeException(
-                MsprobeException.INVALID_PARAM_ERROR,
+                MsprobeException.INVALID_PARAM_ERROR, error_info)
 
     def _check_and_adjust_config_with_l2(self):
         if self.scope:
```
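The rewritten `check_model` now accepts either a single `torch.nn.Module` or a list/tuple of modules, and reports the first offending element otherwise. A hedged sketch of what passes and what raises; the config path is hypothetical:

```python
import torch.nn as nn
from msprobe.pytorch import PrecisionDebugger

debugger = PrecisionDebugger(config_path="./config.json")

# Accepted at L0/mix level: one module, or a list/tuple of modules,
# preferably passed to start() since the constructor's 'model' parameter
# is being deprecated (see the precision_debugger.py diff below).
PrecisionDebugger.start(model=[nn.Linear(4, 4), nn.ReLU()])

# Rejected: any element that is not a torch.nn.Module, e.g.
# model=[nn.Linear(4, 4), "relu"] raises
# MsprobeException(INVALID_PARAM_ERROR) naming the str element.
```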
msprobe/pytorch/debugger/precision_debugger.py

```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,6 +22,7 @@ from msprobe.core.common.file_utils import FileChecker
 from msprobe.core.common.utils import get_real_step_or_rank
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
+from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper
 from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
 from msprobe.pytorch.pt_config import parse_json_config
 from msprobe.pytorch.service import Service
@@ -49,7 +50,7 @@ class PrecisionDebugger:
                  dump_path=None,
                  level=None,
                  model=None,
-                 step=None
+                 step=None
                  ):
         if not hasattr(self, "initialized"):
             config_params = ConfigParameters(config_path,
@@ -59,7 +60,6 @@ class PrecisionDebugger:
                                              model)
             self.check_input_params(config_params)
 
-            self.api_origin = False
             self.initialized = True
             self.model = model
             common_config, task_config = parse_json_config(config_path, task)
@@ -67,12 +67,13 @@ class PrecisionDebugger:
             if self.task == Const.GRAD_PROBE:
                 self.gm = GradientMonitor(common_config, task_config)
                 return
-            if step:
+            if step is not None:
                 common_config.step = get_real_step_or_rank(step, Const.STEP)
             self.config = DebuggerConfig(
                 common_config, task_config, task, dump_path, level
             )
             self.service = Service(self.config)
+            self.module_dumper = ModuleDumper(self.service)
             self.enable_dataloader = self.config.enable_dataloader
             if self.enable_dataloader:
                 logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
@@ -105,9 +106,11 @@ class PrecisionDebugger:
             raise MsprobeException(
                 MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
 
-        if args.model is not None
-
-
+        if args.model is not None:
+            logger.warning_on_rank_0(
+                "The 'model' parameter in the PrecisionDebugger will be deprecated in the future. "
+                "It is recommended to pass the 'model' parameter in the start interface instead."
+            )
 
     @classmethod
     def start(cls, model=None):
@@ -120,15 +123,12 @@ class PrecisionDebugger:
         if instance.enable_dataloader:
             logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
         else:
-            instance.service.start(instance.model
-            instance.api_origin = False
+            instance.service.start(instance.model)
 
-    # marks the end of backward for the targeted code segment; computation data produced afterwards is ignored and cannot be dumped
     @classmethod
     def forward_backward_dump_end(cls):
         instance = cls._instance
-        instance.
-        instance.api_origin = True
+        instance.stop()
 
     @classmethod
     def stop(cls):
@@ -159,6 +159,36 @@ class PrecisionDebugger:
         cls._instance.gm.monitor(model)
 
 
+def module_dump(module, dump_name):
+    if not isinstance(module, torch.nn.Module):
+        raise MsprobeException(
+            MsprobeException.INVALID_PARAM_ERROR,
+            f"the module argument in module_dump must be a torch.nn.Module subclass"
+        )
+    if not isinstance(dump_name, str):
+        raise MsprobeException(
+            MsprobeException.INVALID_PARAM_ERROR,
+            f"the dump_name argument in module_dump must be a str type"
+        )
+    instance = PrecisionDebugger._instance
+    if not instance:
+        raise MsprobeException(
+            MsprobeException.INTERFACE_USAGE_ERROR,
+            f"PrecisionDebugger must be instantiated before using module_dump interface"
+        )
+    instance.module_dumper.start_module_dump(module, dump_name)
+
+
+def module_dump_end():
+    instance = PrecisionDebugger._instance
+    if not instance:
+        raise MsprobeException(
+            MsprobeException.INTERFACE_USAGE_ERROR,
+            f"PrecisionDebugger must be instantiated before using module_dump_end interface"
+        )
+    instance.module_dumper.stop_module_dump()
+
+
 def iter_tracer(func):
     def func_wrapper(*args, **kwargs):
         debugger_instance = PrecisionDebugger.instance
```
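The new module-level `module_dump`/`module_dump_end` pair scopes dumping to a single module: APIs are temporarily restored to their originals so only the chosen module's hooks fire. A hedged usage sketch; the config path is hypothetical and the imports assume the locations shown in this diff:

```python
import torch
import torch.nn as nn
from msprobe.pytorch import PrecisionDebugger
from msprobe.pytorch.debugger.precision_debugger import module_dump, module_dump_end

debugger = PrecisionDebugger(config_path="./config.json")
net = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

module_dump(net, "MySequential")        # hooks only this module
out = net(torch.randn(2, 8, requires_grad=True))
out.sum().backward()
module_dump_end()                       # removes the handles, re-wraps the APIs
```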
msprobe/pytorch/dump/module_dump/module_dump.py (new file)

```diff
@@ -0,0 +1,86 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from msprobe.core.common.const import Const
+from msprobe.core.data_dump.scope import BaseScope
+from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.hook_module.api_registry import api_register
+
+torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
+
+
+class ModuleDumper:
+    def __init__(self, service):
+        self.service = service
+        self.hook_handle_list = []
+
+    def start_module_dump(self, module, dump_name):
+        api_register.api_originality()
+        self.register_hook(module, dump_name)
+
+    def stop_module_dump(self):
+        api_register.api_modularity()
+        for hook_handle in self.hook_handle_list:
+            if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
+                hook_handle.remove()
+        self.hook_handle_list.clear()
+
+    def register_hook(self, module, dump_name):
+        prefix_name = (
+            BaseScope.Module_Type_Module + Const.SEP +
+            dump_name + Const.SEP +
+            module.__class__.__name__ + Const.SEP
+        )
+        module_processor = self.service.module_processor
+        _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.service.build_hook(
+            BaseScope.Module_Type_Module,
+            prefix_name
+        )
+
+        if module_processor.has_register_backward_hook(module):
+            logger.warning(
+                f"The {dump_name} module has registered deprecated register_backward_hook,"
+                f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
+            )
+        if torch_version_above_or_equal_2:
+            forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True)
+        else:
+            if not module_processor.has_register_backward_hook(module):
+                backward_hook_handle = module.register_full_backward_hook(
+                    module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
+                )
+                self.hook_handle_list.append(backward_hook_handle)
+            forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2)
+        self.hook_handle_list.append(forward_hook_handle)
+        if not module_processor.has_register_backward_hook(module):
+            backward_hook_handle = module.register_full_backward_hook(backward_hook)
+            self.hook_handle_list.append(backward_hook_handle)
+
+        forward_pre_hook_handle = module.register_forward_pre_hook(
+            module_processor.node_hook(prefix_name + Const.FORWARD, Const.START)
+        )
+        forward_hook_handle = module.register_forward_hook(
+            module_processor.node_hook(prefix_name + Const.FORWARD, Const.STOP)
+        )
+        self.hook_handle_list.extend([forward_pre_hook_handle, forward_hook_handle])
+        if torch_version_above_or_equal_2 and not module_processor.has_register_backward_hook(module):
+            backward_pre_hook_handle = module.register_full_backward_pre_hook(
+                module_processor.node_hook(prefix_name + Const.BACKWARD, Const.START)
+            )
+            backward_hook_handle = module.register_full_backward_hook(
+                module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
+            )
+            self.hook_handle_list.extend([backward_pre_hook_handle, backward_hook_handle])
```
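ModuleDumper records every `RemovableHandle` it creates so `stop_module_dump` can detach them cleanly. A standalone, framework-free illustration of the same bookkeeping pattern (module and hook names are illustrative):

```python
import torch
import torch.nn as nn

handles = []
mod = nn.Linear(3, 3)

def fwd_hook(module, args, output):
    print("forward fired on", type(module).__name__)

handles.append(mod.register_forward_hook(fwd_hook))
mod(torch.randn(1, 3))            # hook fires

# Mirror of stop_module_dump(): remove and forget every handle.
for h in handles:
    if isinstance(h, torch.utils.hooks.RemovableHandle):
        h.remove()
handles.clear()
mod(torch.randn(1, 3))            # hook no longer fires
```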
msprobe/pytorch/dump/module_dump/module_processer.py (renamed from msprobe/pytorch/module_processer.py)

```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,12 +17,24 @@ from functools import wraps
 
 import torch
 from msprobe.core.common.const import Const
-from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope
+from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
+from msprobe.pytorch.common.log import logger
+from torch.utils.checkpoint import checkpoint as origin_checkpoint
+from torch.utils.checkpoint import set_checkpoint_early_stop
 from torch.utils.hooks import BackwardHook
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 
 
+def checkpoint_without_early_stop(*args, **kwargs):
+    with set_checkpoint_early_stop(False):
+        return origin_checkpoint(*args, **kwargs)
+
+
+def replace_checkpoint():
+    torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop
+
+
 class ModuleProcesser:
     module_count = {}
     module_stack = []
@@ -34,6 +46,7 @@ class ModuleProcesser:
         BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook)
         BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
         BackwardHook.setup_output_hook = ModuleProcesser.filter_tensor_and_tuple(BackwardHook.setup_output_hook)
+        replace_checkpoint()
 
     @staticmethod
     def filter_tensor_and_tuple(func):
@@ -63,7 +76,7 @@ class ModuleProcesser:
             return ModuleProcesser.clone_if_tensor(result)
 
         return clone_return_value_func
-
+
     @staticmethod
     def clone_if_tensor(result):
         if isinstance(result, torch.Tensor):
@@ -85,6 +98,22 @@ class ModuleProcesser:
             ModuleProcesser.module_count[module_name] += 1
         return ModuleProcesser.module_count[module_name]
 
+    @staticmethod
+    def has_register_backward_hook(module):
+        return hasattr(module, '_backward_hooks') and \
+               len(module._backward_hooks) > 0 and \
+               module._is_full_backward_hook is False
+
+    @staticmethod
+    def get_modules_and_names(models):
+        modules_and_names_with_index = {}
+        if isinstance(models, (list, tuple)):
+            for index, model in enumerate(models):
+                modules_and_names_with_index[str(index)] = model.named_modules()
+        else:
+            modules_and_names_with_index["-1"] = models.named_modules()
+        return modules_and_names_with_index
+
     @classmethod
     def reset_module_stats(cls):
         cls.module_count = {}
@@ -92,6 +121,42 @@ class ModuleProcesser:
         cls.api_parent_node = ""
         cls.module_node = {}
 
+    def register_module_hook(self, models, build_hook):
+        logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.")
+        modules_and_names_with_index = self.get_modules_and_names(models)
+        for index, modules_and_names in modules_and_names_with_index.items():
+            model = models if index == "-1" else models[int(index)]
+            for name, module in modules_and_names:
+                if module == model:
+                    continue
+                module_index = (index + Const.SEP) if index != "-1" else ""
+                prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
+                               name + Const.SEP + module.__class__.__name__ + Const.SEP)
+                pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = build_hook(
+                    BaseScope.Module_Type_Module,
+                    prefix_name
+                )
+
+                if self.has_register_backward_hook(module):
+                    logger.warning(
+                        f"The {prefix_name[:-1]} has registered deprecated register_backward_hook,"
+                        f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
+                    )
+                if torch_version_above_or_equal_2:
+                    module.register_forward_hook(forward_hook, with_kwargs=True)
+                else:
+                    if not self.has_register_backward_hook(module):
+                        module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
+                    module.register_forward_hook(forward_hook_torch_version_below_2)
+                if not self.has_register_backward_hook(module):
+                    module.register_full_backward_hook(backward_hook)
+
+                module.register_forward_pre_hook(self.node_hook(prefix_name + Const.FORWARD, Const.START))
+                module.register_forward_hook(self.node_hook(prefix_name + Const.FORWARD, Const.STOP))
+                if torch_version_above_or_equal_2 and not self.has_register_backward_hook(module):
+                    module.register_full_backward_pre_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.START))
+                    module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
+
     def node_hook(self, name_prefix, start_or_stop, **kwargs):
 
         def pre_hook(module, input, output=None):
@@ -100,7 +165,10 @@ class ModuleProcesser:
             except IndexError as e:
                 index = None
                 pass
-
+            full_name = name_prefix + Const.SEP + str(index)
+            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
+                module.mindstudio_reserved_name = []
+            module.mindstudio_reserved_name.append(full_name)
             if self.module_stack:
                 ModuleProcesser.module_node[full_name] = self.module_stack[-1]
             else:
@@ -119,8 +187,11 @@ class ModuleProcesser:
                 ModuleProcesser.api_parent_node = self.module_stack[-1]
             else:
                 ModuleProcesser.api_parent_node = None
+            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
+                raise RuntimeError(f"module reserve name is None when pop")
+            current_name = module.mindstudio_reserved_name.pop()
             if self.scope:
-                self.scope.end_module(
+                self.scope.end_module(current_name)
 
         def backward_hook(module, input, output=None):
             try:
@@ -128,7 +199,10 @@ class ModuleProcesser:
             except IndexError as e:
                 index = None
                 pass
-
+            full_name = name_prefix + Const.SEP + str(index)
+            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
+                module.mindstudio_reserved_name = []
+            module.mindstudio_reserved_name.append(full_name)
             forward_full_name = full_name.replace(Const.BACKWARD, Const.FORWARD)
             ModuleProcesser.module_node[full_name] = ModuleProcesser.module_node[forward_full_name].replace(
                 Const.FORWARD, Const.BACKWARD) if ModuleProcesser.module_node[forward_full_name] else None
```
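The checkpoint patch in this diff exists because non-reentrant activation checkpointing can stop recomputation early, which would silence the hooks the dumper relies on during backward; the wrapper simply forces early stopping off around every checkpoint call. A standalone sketch of the same idea, using the `set_checkpoint_early_stop` context manager imported above:

```python
import torch
from torch.utils.checkpoint import checkpoint as origin_checkpoint
from torch.utils.checkpoint import set_checkpoint_early_stop

def checkpoint_without_early_stop(*args, **kwargs):
    # Run the original checkpoint with early stopping disabled so every
    # recomputed op still fires its hooks during backward.
    with set_checkpoint_early_stop(False):
        return origin_checkpoint(*args, **kwargs)

seg = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
y = checkpoint_without_early_stop(seg, x, use_reentrant=False)
y.sum().backward()
```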
msprobe/pytorch/free_benchmark/common/params.py

```diff
@@ -39,7 +39,6 @@ class DataParams:
     origin_func: Optional[Callable] = None
     api_type: Optional[str] = None
     fuzz_stage: Optional[str] = None
-    grad_unequal_flag: Optional[bool] = True
 
 
 @dataclass
@@ -127,6 +126,8 @@ def make_unequal_row(
     )
     if isinstance(ratio, float):
         row.max_rel = ratio - 1
+    if isinstance(ratio, str):
+        row.max_rel = ratio
     origin_tensor = data_params.original_result
     perturbed_tensor = data_params.perturbed_result
     if index is not None:
```
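The added branch lets `make_unequal_row` carry a non-numeric ratio straight into `max_rel`, where previously only float ratios were handled. A small sketch of the two branches in isolation; the `Row` type and the "N/A" sentinel are simplified stand-ins:

```python
from dataclasses import dataclass
from typing import Optional, Union

@dataclass
class Row:
    max_rel: Optional[Union[float, str]] = None

def fill_max_rel(row, ratio):
    if isinstance(ratio, float):
        row.max_rel = ratio - 1   # numeric ratio -> relative error
    if isinstance(ratio, str):
        row.max_rel = ratio       # sentinel string passes through unchanged
    return row

print(fill_max_rel(Row(), 1.25).max_rel)   # 0.25
print(fill_max_rel(Row(), "N/A").max_rel)  # "N/A"
```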
msprobe/pytorch/free_benchmark/common/utils.py

```diff
@@ -124,6 +124,7 @@ class TorchC:
     abs = torch._C._VariableFunctionsClass.abs
     where = torch._C._VariableFunctionsClass.where
     div = torch._C._VariableFunctionsClass.div
+    mul = torch._C._VariableFunctionsClass.mul
     max = torch._C._VariableFunctionsClass.max
     min = torch._C._VariableFunctionsClass.min
     gt = torch._C._VariableFunctionsClass.gt
@@ -138,3 +139,5 @@ class TorchC:
     tensor_split = torch._C._VariableFunctionsClass.tensor_split
     stack = torch._C._VariableFunctionsClass.stack
     reshape = torch._C._VariableFunctionsClass.reshape
+    nan_to_num = torch._C._VariableFunctionsClass.nan_to_num
+    aminmax = torch._C._VariableFunctionsClass.aminmax
```
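TorchC pins raw `torch._C` bindings so free-benchmark arithmetic does not re-enter the hooked public torch API. Assuming the three new entries mirror their public counterparts, as the existing entries do, they behave as below:

```python
import torch

mul = torch._C._VariableFunctionsClass.mul
nan_to_num = torch._C._VariableFunctionsClass.nan_to_num
aminmax = torch._C._VariableFunctionsClass.aminmax

t = torch.tensor([1.0, float('nan'), -2.0])
print(mul(t, 2.0))                 # same result as torch.mul
cleaned = nan_to_num(t)            # NaN -> 0.0 by default
print(aminmax(cleaned))            # (min, max) computed in one pass
```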
msprobe/pytorch/free_benchmark/compare/grad_saver.py

```diff
@@ -82,13 +82,11 @@ class GradSaver:
         data_params = DataParams()
         data_params.original_result = origin_grad
         data_params.perturbed_result = perturbed_grad
-        data_params.grad_unequal_flag = False
         data_params.valid_input_index = index
         try:
             handler.handle(data_params)
             if not data_params.is_consistent:
                 self.is_compare = False
-                data_params.grad_unequal_flag = True
                 data_params.is_consistent = True
                 data_params.perturbed_result = self.perturbed_grad_input
                 data_params.original_result = self.origin_grad_input
```