mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
from abc import ABC, abstractmethod
|
|
18
|
+
|
|
19
|
+
from msprobe.core.config_check.utils.utils import config_checking_print
|
|
20
|
+
from msprobe.core.common.file_utils import FileOpen, load_yaml
|
|
21
|
+
from msprobe.core.common.const import Const, FileCheckConst
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Parser(ABC):
    """Abstract base class for hyperparameter file parsers."""

    @abstractmethod
    def parse(self, file_path: str) -> dict:
        """Parse the file at ``file_path`` and return the extracted hyperparameters."""

    def run(self, file_path: str) -> dict:
        """
        Unified public entry point.

        Any exception raised by ``parse`` is reported through
        ``config_checking_print`` and replaced by an empty result, so one
        unparsable file does not abort the caller.

        :param file_path: path of the file to parse
        :return: the dict returned by ``parse``, or ``{}`` if parsing raised
        """
        try:
            return self.parse(file_path)
        except Exception as exc:
            config_checking_print(f"{self.__class__} parsing error, skip file path: {file_path}, error: {exc}")
            return {}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class ShellParser(Parser):
    """Parser for shell scripts that launch training via msrun, torchrun or torch.distributed.launch."""

    def parse(self, file_path: str) -> dict:
        """
        Extracts arguments from bash script used to run a model training.

        The launch command is located in the script with comments removed,
        ``*_ARGS="..."`` blocks referenced from it are expanded, and every
        ``--option [value]`` / ``--option=value`` pair is collected. A value
        of the form ``$VAR`` / ``${VAR}`` is replaced by the text of the last
        ``VAR=...`` assignment found in the script.

        :param file_path: path of the shell script
        :return: dict mapping option names (without the leading ``--``) to the
                 value string, or to ``True`` for options without a value;
                 empty when the script contains no supported launch command
        """
        hyperparameters = {}
        script_content_list = []
        with FileOpen(file_path, 'r') as file:
            for line in file:
                # Everything from the first '#' on is treated as a comment; lines
                # that are empty afterwards (blank or comment-only) are dropped.
                code = line.split('#')[0].rstrip()
                if code.strip():
                    script_content_list.append(code + '\n')
        script_content = ''.join(script_content_list)

        launch = re.search(r'msrun\s[^|]*|torchrun\s[^|]*|python\d? -m torch.distributed.launch\s[^|]*',
                           script_content,
                           re.DOTALL)
        if not launch:
            # No launch command: nothing to extract. Previously this path crashed
            # on ``None`` further down.
            return hyperparameters
        command_line = launch.group()

        blocks = re.findall(r'([a-zA-Z0-9_]{1,20}_ARGS)="(.*?)"', script_content, re.DOTALL)
        for block_name, block_content in blocks:
            block_content = block_content.replace('\n', ' ')
            # Expand both ``${NAME}`` and ``$NAME``. The trailing \b keeps ``$NAME``
            # from being replaced inside a longer variable name such as ``$NAME_EXTRA``.
            reference = re.compile(rf'\$\{{{block_name}\}}|\${block_name}\b')
            # A callable replacement keeps backslashes in the block text literal.
            command_line = reference.sub(lambda _match, text=block_content: text, command_line)

        # The value is optional and may be attached with '=' or whitespace. The
        # (?!--) lookahead stops a valueless flag from swallowing the next option
        # as its value when both are on the same line.
        matches = re.findall(r'--([\w-]+)(?:(?:=|\s+(?!--))([^\s\\]+))?', command_line)
        for key, value in matches:
            args_key = re.match(r'\$\{?(\w+)}?', value)
            if args_key:
                # \b anchors the variable name so that e.g. ``LR`` does not also
                # match the assignment ``MIN_LR=...``.
                env_vars = re.findall(rf'\b{args_key.group(1)}=\s*(.+)', script_content)
                if env_vars:
                    # The last assignment in the script wins.
                    value = env_vars[-1]
            hyperparameters[key] = value if value else True

        return hyperparameters
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class YamlParser(Parser):
    """Parser that flattens a YAML configuration into ``{dotted.key: value}`` pairs."""

    # Class-level default kept for backward compatibility with code that reads
    # ``YamlParser.hyperparameters``. Instances always work on their own dict
    # (see __init__ / parse), so results no longer leak between files or instances.
    hyperparameters = {}

    def __init__(self):
        self.hyperparameters = {}

    def parse(self, file_path: str) -> dict:
        """
        Load the YAML file and return its flattened scalar entries.

        :param file_path: path of the YAML file
        :return: a new dict for every call, mapping the key path (joined with
                 ``Const.SEP``) to each int/str/bool leaf value
        """
        ori_hyper = load_yaml(file_path)
        # Start from a fresh dict: the same parser instance is reused for several
        # files, and reusing one dict would merge their entries and hand every
        # caller the same object.
        self.hyperparameters = {}
        self.recursive_parse_parameters(ori_hyper, "")
        return self.hyperparameters

    def recursive_parse_parameters(self, parameters, prefix):
        """
        Walk ``parameters`` and record every int/str/bool leaf under its key path.

        Dict keys extend the path; list items are visited with the unchanged
        path, so a later item overwrites an earlier one stored under the same
        path. Other leaf types (for example float or None) are not recorded.
        NOTE(review): confirm that skipping float values is intended.

        :param parameters: the (sub)structure loaded from YAML
        :param prefix: key path accumulated so far, "" at the top level
        """
        if isinstance(parameters, dict):
            for key, value in parameters.items():
                # Keys are formatted instead of concatenated so that non-string
                # YAML keys (e.g. integers) do not raise TypeError.
                new_prefix = f"{prefix}{Const.SEP}{key}" if prefix else str(key)
                self.recursive_parse_parameters(value, new_prefix)
        elif isinstance(parameters, list):
            for value in parameters:
                self.recursive_parse_parameters(value, prefix)
        elif isinstance(parameters, (int, str, bool)):
            self.hyperparameters[prefix] = parameters
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class ParserFactory:
    """Looks up the parser instance registered for a file suffix."""

    # One shared parser instance per supported suffix, created when the class is defined.
    __ParserDict = {
        FileCheckConst.SHELL_SUFFIX: ShellParser(),
        FileCheckConst.YAML_SUFFIX: YamlParser()
    }

    def get_parser(self, file_type: str) -> Parser:
        """
        Return the parser registered for ``file_type``.

        :param file_type: key to look up (one of the suffix constants used above)
        :return: the registered parser instance
        :raises ValueError: if no parser is registered for ``file_type``
        """
        parser = self.__ParserDict.get(file_type)
        if parser:
            return parser
        raise ValueError(f'Invalid parser type: {file_type}')
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import re
|
|
18
|
+
import hashlib
|
|
19
|
+
|
|
20
|
+
from msprobe.core.common.framework_adapter import FmkAdp
|
|
21
|
+
from msprobe.core.common.log import logger
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def merge_keys(dir_0, dir_1):
    """Return the set of keys that appear in ``dir_0`` or ``dir_1``."""
    return {*dir_0.keys(), *dir_1.keys()}
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def compare_dict(bench_dict, cmp_dict):
    """
    Describe the key-level differences between two flat dicts.

    :param bench_dict: baseline dict
    :param cmp_dict: dict compared against the baseline
    :return: list with one message per differing key: ``"key: old -> new"``
             for a changed value, ``"key: [deleted] -> old"`` for a key only
             in ``bench_dict`` and ``"key: [added] -> new"`` for a key only in
             ``cmp_dict``. Messages follow set iteration order.
    """
    messages = []
    all_keys = set(bench_dict.keys()) | set(cmp_dict.keys())
    for key in all_keys:
        if key not in bench_dict:
            messages.append(f"{key}: [added] -> {cmp_dict[key]}")
            continue
        if key not in cmp_dict:
            messages.append(f"{key}: [deleted] -> {bench_dict[key]}")
            continue
        if bench_dict[key] != cmp_dict[key]:
            messages.append(f"{key}: {bench_dict[key]} -> {cmp_dict[key]}")
    return messages
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def config_checking_print(msg):
    """Log ``msg`` at INFO level, prefixed with the config-checking tag."""
    tagged = "[config checking log] {}".format(msg)
    logger.info(tagged)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def tensor_to_hash(tensor):
    """Compute the hash value of a tensor from the raw bytes of its CPU copy."""
    host_copy = tensor.clone().detach().cpu()
    raw_bytes = host_copy.numpy().tobytes()
    return bytes_hash(raw_bytes)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_tensor_features(tensor):
    """
    Collect summary statistics of ``tensor`` through the framework adapter.

    :param tensor: tensor passed unchanged to the ``FmkAdp.tensor_*`` helpers
    :return: dict with the keys "max", "min", "mean" and "norm"
    """
    stat_names = ("max", "min", "mean", "norm")
    # Each statistic is produced by the adapter method named ``tensor_<stat>``.
    return {name: getattr(FmkAdp, f"tensor_{name}")(tensor) for name in stat_names}
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def compare_dicts(dict1, dict2, path=''):
    """
    Recursively compare two (possibly nested) dicts.

    :param dict1: baseline dict
    :param dict2: dict compared against the baseline
    :param path: '/'-terminated path of the enclosing keys, used to build the
                 messages; '' at the top level
    :return: tuple ``(deleted, added, changed, result)`` where the first three
             are flat lists of ``"[Deleted]: path"``, ``"[Added]: path"`` and
             ``"[Changed]: path : old -> new"`` messages, and ``result`` is a
             dict mirroring the input structure that contains only the
             differing keys, marked ``"[deleted]"``, ``"[added]"`` or
             ``"[changed]: old -> new"``
    """
    deleted = []
    added = []
    changed = []
    result = {}

    for key, value1 in dict1.items():
        # Format the key instead of concatenating so non-string keys
        # (e.g. integers) do not raise TypeError.
        key_path = f"{path}{key}"
        if key not in dict2:
            deleted.append(f"[Deleted]: {key_path}")
            result[key] = "[deleted]"
            continue

        value2 = dict2[key]
        if isinstance(value1, dict) and isinstance(value2, dict):
            sub_deleted, sub_added, sub_changed, sub_result = compare_dicts(value1, value2, f"{key_path}/")
            deleted.extend(sub_deleted)
            added.extend(sub_added)
            changed.extend(sub_changed)
            # Identical sub-dicts leave no entry in the result.
            if sub_result:
                result[key] = sub_result
        elif value1 != value2:
            changed.append(f"[Changed]: {key_path} : {value1} -> {value2}")
            result[key] = f"[changed]: {value1} -> {value2}"

    for key in dict2:
        if key not in dict1:
            added.append(f"[Added]: {path}{key}")
            result[key] = "[added]"
    return deleted, added, changed, result
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def bytes_hash(obj: bytes):
    """
    Return a 16-bit hash of ``obj``.

    The value is the SHA-256 digest interpreted as a big-endian integer modulo
    2**16, which is exactly the integer formed by its last two bytes.

    :param obj: bytes to hash
    :return: int in the range [0, 65535]
    """
    digest = hashlib.sha256(obj).digest()
    return int.from_bytes(digest[-2:], "big")
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def update_dict(ori_dict, new_dict):
    """
    Merge ``new_dict`` into ``ori_dict`` in place, recording conflicting values.

    A key that is new, or whose value equals the stored one, is simply
    assigned. When a key already holds a different value, the entry becomes
    ``{"description": "duplicate_value", "values": [...]}`` listing every
    distinct value seen so far; further distinct values are appended to that
    list.

    :param ori_dict: dict updated in place
    :param new_dict: dict whose items are merged into ``ori_dict``
    :return: None
    """
    for key, value in new_dict.items():
        if key not in ori_dict or ori_dict[key] == value:
            ori_dict[key] = value
            continue

        current = ori_dict[key]
        # Inspect the entry stored under this key, not the top-level dict: the
        # former check ("values" in ori_dict) nested the record on the third
        # distinct value and raised TypeError when ori_dict had a "values" key.
        is_duplicate_record = (
            isinstance(current, dict)
            and current.get("description") == "duplicate_value"
            and isinstance(current.get("values"), list)
        )
        if is_duplicate_record:
            # Keep the list free of repeats of an already recorded value.
            if value not in current["values"]:
                current["values"].append(value)
        else:
            ori_dict[key] = {"description": "duplicate_value", "values": [current, value]}
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import inspect
|
|
17
|
+
from typing import Dict, Any, Optional, Callable, Union, List, Tuple
|
|
18
|
+
|
|
19
|
+
from msprobe.core.common.const import Const
|
|
20
|
+
from msprobe.core.common.file_utils import load_yaml
|
|
21
|
+
from msprobe.core.common.log import logger
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _get_attr(module, attr_name):
    """
    Look up ``attr_name`` on ``module``, resolving one level of sub-module.

    When ``attr_name`` contains ``Const.SEP`` it is split at the last
    separator: the left part is fetched from ``module`` with a single
    ``getattr`` and the right part is fetched from that object.

    :return: the attribute, or None if any lookup fails
    """
    if Const.SEP not in attr_name:
        return getattr(module, attr_name, None)
    sub_module_name, sub_attr = attr_name.rsplit(Const.SEP, 1)
    # getattr on None simply yields the default, so a missing sub-module gives None.
    return getattr(getattr(module, sub_module_name, None), sub_attr, None)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ApiWrapper:
    """
    Builds wrapped versions of framework APIs.

    ``api_types`` maps a framework name to ``{api_type: api_modules}``, where
    ``api_modules[0]`` is the object the original APIs are looked up on. The
    names to wrap are read from the YAML file(s) in ``api_list_paths``.
    """

    def __init__(
            self, api_types: Dict[str, Dict[str, Any]],
            api_list_paths: Union[str, List[str], Tuple[str]],
            backlist: Union[List[str], Tuple[str]] = None
    ):
        """
        :param api_types: ``{framework: {api_type: api_modules}}``
        :param api_list_paths: one YAML path shared by all frameworks, or a
                               list/tuple with exactly one path per framework
        :param backlist: ``"<key_in_file>.<api_name>"`` entries to exclude
        :raises RuntimeError: if a list/tuple of paths does not match the
                              number of frameworks
        """
        self.api_types = api_types
        if isinstance(api_list_paths, (list, tuple)):
            if len(api_list_paths) != len(self.api_types):
                raise RuntimeError("The number of api_list_paths must be equal to the number of frameworks in 'api_types', "
                                   "when api_list_paths is a list or tuple.")
        else:
            # A single path is reused for every framework.
            api_list_paths = [api_list_paths] * len(self.api_types)
        self.api_list_paths = api_list_paths
        self.backlist = backlist if backlist else []
        self.api_names = self._get_api_names()
        self.wrapped_api_functions = dict()

    @staticmethod
    def deal_with_self_kwargs(api_name, api_func, args, kwargs):
        """
        Move a ``self`` passed as keyword argument into the positional arguments.

        When ``kwargs`` contains ``'self'``, the signature of ``api_func`` (or,
        if it cannot be inspected, of the entry for ``api_name`` in
        ``Const.API_WITH_SELF_ARG``) is used to rebuild the positional
        arguments up to and including ``self``: each missing position is
        filled from ``kwargs`` (which is mutated) or from the parameter default.

        :return: ``(enable_wrap, args, kwargs)``; ``enable_wrap`` is False when
                 no signature is available or ``self`` is keyword-only
        """
        if not kwargs or 'self' not in kwargs:
            return True, args, kwargs

        try:
            func_params = inspect.signature(api_func).parameters
        except Exception:
            func_params = None
            if api_name in Const.API_WITH_SELF_ARG:
                func_params = inspect.signature(Const.API_WITH_SELF_ARG.get(api_name)).parameters
        if func_params is None:
            return False, args, kwargs

        self_param = func_params.get('self')
        if self_param is not None and self_param.kind == inspect.Parameter.KEYWORD_ONLY:
            # A keyword-only ``self`` cannot be turned into a positional argument.
            return False, args, kwargs

        # Collect (name, default) for every parameter up to and including ``self``.
        leading_params = []
        self_index = 0
        for position, (name, param) in enumerate(func_params.items()):
            leading_params.append((name, param.default))
            if name == 'self':
                self_index = position
                break

        positional = list(args)
        for position in range(len(args), self_index + 1):
            name, default = leading_params[position]
            positional.append(kwargs.pop(name) if name in kwargs else default)

        return True, tuple(positional), kwargs

    def wrap_api(
            self, api_templates, hook_build_func: Optional[Callable]
    ):
        """
        Create the wrapper functions for every valid API name.

        :param api_templates: one template for all API types, or a list/tuple
                              with one template per (framework, api_type) pair
                              in iteration order
        :param hook_build_func: passed through to the template on every call
        :return: ``{framework: {api_type: {api_name: wrapper}}}`` (also stored
                 in ``self.wrapped_api_functions``)
        :raises RuntimeError: if a list/tuple of templates has the wrong length
        """
        api_types_num = sum(len(types) for types in self.api_types.values())
        if isinstance(api_templates, (list, tuple)):
            if len(api_templates) != api_types_num:
                raise RuntimeError("The number of api_templates must be equal to the number of api_types, "
                                   "when api_templates is a list or tuple.")
        else:
            api_templates = [api_templates] * api_types_num

        def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template):
            # Factory: binds the per-API values so each wrapper keeps its own.
            def api_function(*args, **kwargs):
                api_name_with_prefix = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1])
                enable_wrap, args, kwargs = self.deal_with_self_kwargs(api_name_with_prefix,
                                                                       api_func, args, kwargs)
                if not enable_wrap:
                    logger.warning(f'Cannot collect precision data of {api_name_with_prefix}. '
                                   'It may be fixed by passing the value of "self" '
                                   'as a positional argument instead of a keyword argument. ')
                    return api_func(*args, **kwargs)
                return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs)
            api_function.__name__ = api_name
            return api_function

        self.wrapped_api_functions.clear()
        # Templates are consumed in the same order the (framework, api_type) pairs are visited.
        template_iter = iter(api_templates)
        for framework, api_types in self.api_types.items():
            wrapped_in_framework = dict()
            for api_type, api_modules in api_types.items():
                wrapped_in_type = dict()
                name_prefix = Const.API_DATA_PREFIX.get(framework, {}).get(api_type, "API")
                api_template = next(template_iter)
                for api_name in self.api_names.get(framework, {}).get(api_type, []):
                    ori_api = _get_attr(api_modules[0], api_name)
                    if not callable(ori_api):
                        continue
                    wrapped_in_type[api_name] = wrap_api_func(api_name, ori_api, name_prefix,
                                                              hook_build_func, api_template)
                wrapped_in_framework[api_type] = wrapped_in_type
            self.wrapped_api_functions[framework] = wrapped_in_framework
        return self.wrapped_api_functions

    def _get_api_names(self):
        """
        Read the supported-API YAML files and keep the names that exist.

        A name is kept when it is not in ``self.backlist`` and its final
        component is listed by ``dir()`` on ``api_modules[0]`` (or on the named
        sub-module for names containing ``Const.SEP``).

        :return: ``{framework: {api_type: set_of_api_names}}``
        """
        api_names = dict()

        for index, (framework, types_of_framework) in enumerate(self.api_types.items()):
            api_list = load_yaml(self.api_list_paths[index])
            valid_names = dict()
            for api_type, api_modules in types_of_framework.items():
                key_in_file = Const.SUPPORT_API_DICT_KEY_MAP.get(framework, {}).get(api_type)
                names = set()
                for api_name in api_list.get(key_in_file, []):
                    if f'{key_in_file}.{api_name}' in self.backlist:
                        continue
                    if Const.SEP in api_name:
                        sub_module_name, target_attr = api_name.rsplit(Const.SEP, 1)
                        target_module = getattr(api_modules[0], sub_module_name, None)
                    else:
                        target_attr = api_name
                        target_module = api_modules[0]
                    if target_module and target_attr in dir(target_module):
                        names.add(api_name)
                valid_names[api_type] = names
            api_names[framework] = valid_names

        return api_names
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class ApiRegistry:
|
|
151
|
+
"""
|
|
152
|
+
Base class for api registry.
|
|
153
|
+
"""
|
|
154
|
+
|
|
155
|
+
def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates, backlist=None):
|
|
156
|
+
self.ori_api_attr = dict()
|
|
157
|
+
self.wrapped_api_attr = dict()
|
|
158
|
+
self.inner_used_ori_attr = dict()
|
|
159
|
+
self.inner_used_wrapped_attr = dict()
|
|
160
|
+
self.api_types = api_types
|
|
161
|
+
self.inner_used_api = inner_used_api
|
|
162
|
+
self.supported_api_list_path = supported_api_list_path
|
|
163
|
+
self.api_templates = api_templates
|
|
164
|
+
self.backlist = backlist if backlist else []
|
|
165
|
+
self.all_api_registered = False
|
|
166
|
+
|
|
167
|
+
@staticmethod
def store_ori_attr(ori_api_group, api_list, api_ori_attr):
    """
    Record the current attribute of ``ori_api_group`` for every name in ``api_list``.

    :param ori_api_group: object the attributes are read from
    :param api_list: iterable of attribute names (may contain one ``Const.SEP``-separated sub-module)
    :param api_ori_attr: dict updated in place with ``{name: attribute or None}``
    """
    api_ori_attr.update((api, _get_attr(ori_api_group, api)) for api in api_list)
|
|
171
|
+
|
|
172
|
+
@staticmethod
|
|
173
|
+
def set_api_attr(api_group, attr_dict):
|
|
174
|
+
for api, api_attr in attr_dict.items():
|
|
175
|
+
if Const.SEP in api:
|
|
176
|
+
sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
|
|
177
|
+
sub_module = getattr(api_group, sub_module_name, None)
|
|
178
|
+
if sub_module is not None:
|
|
179
|
+
setattr(sub_module, sub_op, api_attr)
|
|
180
|
+
else:
|
|
181
|
+
setattr(api_group, api, api_attr)
|
|
182
|
+
|
|
183
|
+
@staticmethod
|
|
184
|
+
def register_custom_api(module, api_name, api_prefix, hook_build_func, api_template):
|
|
185
|
+
def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template):
|
|
186
|
+
def api_function(*args, **kwargs):
|
|
187
|
+
return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs)
|
|
188
|
+
|
|
189
|
+
api_function.__name__ = api_name
|
|
190
|
+
return api_function
|
|
191
|
+
|
|
192
|
+
setattr(module, api_name,
|
|
193
|
+
wrap_api_func(api_name, getattr(module, api_name), api_prefix, hook_build_func, api_template))
|
|
194
|
+
|
|
195
|
+
def register_all_api(self):
|
|
196
|
+
self.all_api_registered = True
|
|
197
|
+
for framework, api_types in self.api_types.items():
|
|
198
|
+
for api_type, api_modules in api_types.items():
|
|
199
|
+
api_type_with_framework = framework + Const.SEP + api_type
|
|
200
|
+
for module in api_modules[1]:
|
|
201
|
+
self.set_api_attr(module, self.wrapped_api_attr.get(api_type_with_framework, {}))
|
|
202
|
+
|
|
203
|
+
def register_inner_used_api(self):
|
|
204
|
+
for api_type in self.inner_used_api.keys():
|
|
205
|
+
self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_wrapped_attr.get(api_type, {}))
|
|
206
|
+
|
|
207
|
+
def restore_all_api(self):
|
|
208
|
+
self.all_api_registered = False
|
|
209
|
+
for framework, api_types in self.api_types.items():
|
|
210
|
+
for api_type, api_modules in api_types.items():
|
|
211
|
+
api_type_with_framework = framework + Const.SEP + api_type
|
|
212
|
+
for module in api_modules[1]:
|
|
213
|
+
self.set_api_attr(module, self.ori_api_attr.get(api_type_with_framework, {}))
|
|
214
|
+
|
|
215
|
+
def restore_inner_used_api(self):
|
|
216
|
+
for api_type in self.inner_used_api.keys():
|
|
217
|
+
self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_ori_attr.get(api_type, {}))
|
|
218
|
+
|
|
219
|
+
def initialize_hook(self, hook_build_func):
|
|
220
|
+
api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path, self.backlist)
|
|
221
|
+
wrapped_api_functions = api_wrapper.wrap_api(self.api_templates, hook_build_func)
|
|
222
|
+
|
|
223
|
+
for framework, api_types in self.api_types.items():
|
|
224
|
+
for api_type, api_modules in api_types.items():
|
|
225
|
+
ori_attr = dict()
|
|
226
|
+
self.store_ori_attr(api_modules[0], api_wrapper.api_names.get(framework).get(api_type), ori_attr)
|
|
227
|
+
api_type_with_framework = framework + Const.SEP + api_type
|
|
228
|
+
self.ori_api_attr[api_type_with_framework] = ori_attr
|
|
229
|
+
self.wrapped_api_attr[api_type_with_framework] = wrapped_api_functions.get(framework).get(api_type)
|
|
230
|
+
|
|
231
|
+
for inner_used_api_type, inner_used_api_list in self.inner_used_api.items():
|
|
232
|
+
ori_attr = dict()
|
|
233
|
+
wrapped_attr = dict()
|
|
234
|
+
for api_name in inner_used_api_list[1:]:
|
|
235
|
+
if self.ori_api_attr.get(inner_used_api_type, {}).get(api_name):
|
|
236
|
+
ori_attr[api_name] = self.ori_api_attr.get(inner_used_api_type).get(api_name)
|
|
237
|
+
wrapped_attr[api_name] = self.wrapped_api_attr.get(inner_used_api_type).get(api_name)
|
|
238
|
+
self.inner_used_ori_attr[inner_used_api_type] = ori_attr
|
|
239
|
+
self.inner_used_wrapped_attr[inner_used_api_type] = wrapped_attr
|
|
@@ -41,7 +41,7 @@ class DataCollector:
|
|
|
41
41
|
self.backward_module_names = {}
|
|
42
42
|
self.optimizer_status = ""
|
|
43
43
|
self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
|
|
44
|
-
atexit.register(self.
|
|
44
|
+
atexit.register(self.write_json_at_exit)
|
|
45
45
|
|
|
46
46
|
@property
|
|
47
47
|
def dump_data_dir(self):
|
|
@@ -78,6 +78,11 @@ class DataCollector:
|
|
|
78
78
|
def write_json(self):
|
|
79
79
|
self.data_writer.write_json()
|
|
80
80
|
|
|
81
|
+
def write_json_at_exit(self):
    """
    Write the json output; this method is registered with ``atexit``.

    When async dump is enabled and the task is ``Const.TENSOR``, the data processor's
    async data is dumped first.
    """
    needs_async_dump = self.config.async_dump and self.config.task == Const.TENSOR
    if needs_async_dump:
        self.data_processor.dump_async_data()
    self.data_writer.write_json()
|
|
85
|
+
|
|
81
86
|
def update_data(self, name, data_info):
|
|
82
87
|
msg = f"msprobe is collecting data on {name}."
|
|
83
88
|
if self.config.task == Const.OVERFLOW_CHECK:
|
|
@@ -89,6 +94,10 @@ class DataCollector:
|
|
|
89
94
|
logger.debug(msg)
|
|
90
95
|
self.data_writer.update_data(data_info)
|
|
91
96
|
|
|
97
|
+
def call_stack_collect(self, name):
    """Analyze the call stack for ``name`` with the data processor and hand it to the data writer."""
    call_stack = self.data_processor.analyze_api_call_stack(name)
    self.data_writer.update_stack(name, call_stack)
|
|
100
|
+
|
|
92
101
|
def forward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
93
102
|
if self.config.task == Const.FREE_BENCHMARK:
|
|
94
103
|
backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
|
|
@@ -118,9 +127,16 @@ class DataCollector:
|
|
|
118
127
|
self.set_is_recomputable(data_info, is_recompute)
|
|
119
128
|
if self.config.level == Const.LEVEL_L2:
|
|
120
129
|
return
|
|
121
|
-
self.
|
|
130
|
+
self.call_stack_collect(name)
|
|
122
131
|
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
123
132
|
|
|
133
|
+
def forward_data_collect_only_tensor(self, name, module, pid, module_input_output):
    """
    Run the data processor's forward analysis for ``name`` when the scope/pid check passes.

    The analysis result is not used here and nothing is returned.
    """
    in_scope = self.check_scope_and_pid(self.scope, name, pid)
    if in_scope:
        self.data_processor.analyze_forward(name, module, module_input_output)
|
|
138
|
+
|
|
139
|
+
|
|
124
140
|
def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
125
141
|
self.update_construct(name)
|
|
126
142
|
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
@@ -130,9 +146,15 @@ class DataCollector:
|
|
|
130
146
|
if self.config.task != Const.STRUCTURE:
|
|
131
147
|
data_info = self.data_processor.analyze_forward(name, module, module_input_output)
|
|
132
148
|
self.set_is_recomputable(data_info, is_recompute)
|
|
133
|
-
self.
|
|
149
|
+
self.call_stack_collect(name)
|
|
134
150
|
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
135
151
|
|
|
152
|
+
def backward_data_collect_only_tensor(self, name, module, pid, module_input_output, is_recompute=None):
    """
    Run the data processor's backward analysis for ``name`` when the scope/pid check passes.

    ``is_recompute`` is accepted but not used by this method. The analysis result is
    not used here and nothing is returned.
    """
    in_scope = self.check_scope_and_pid(self.scope, name, pid)
    if in_scope:
        self.data_processor.analyze_backward(name, module, module_input_output)
|
|
157
|
+
|
|
136
158
|
def backward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
137
159
|
self.update_construct(name)
|
|
138
160
|
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
@@ -180,7 +202,10 @@ class DataCollector:
|
|
|
180
202
|
self.optimizer_status_first_start[self.optimizer_status] = False
|
|
181
203
|
self.data_writer.update_construct({name: self.optimizer_status})
|
|
182
204
|
else:
|
|
183
|
-
self.
|
|
205
|
+
if self.config.level == Const.LEVEL_MIX and \
|
|
206
|
+
not (name.startswith(Const.MODULE) or name.startswith(Const.CELL)):
|
|
207
|
+
self.data_writer.update_construct({name: self.module_processor.api_parent_node})
|
|
208
|
+
|
|
184
209
|
self.data_writer.update_construct(self.module_processor.module_node)
|
|
185
210
|
|
|
186
211
|
def handle_data(self, name, data_info, flush=False):
|
|
@@ -204,6 +229,7 @@ class DataCollector:
|
|
|
204
229
|
|
|
205
230
|
def params_data_collect(self, name, param_name, pid, data):
|
|
206
231
|
grad_name = name + Const.SEP + Const.PARAMS_GRAD
|
|
232
|
+
self.update_api_or_module_name(grad_name)
|
|
207
233
|
# 校验scope和pid,以及当前name是否有过反向计算
|
|
208
234
|
if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
|
|
209
235
|
# 如果没有反向计算,则需要清除之前占位写入的grad数据
|
|
@@ -213,18 +239,19 @@ class DataCollector:
|
|
|
213
239
|
data_info = self.data_processor.analyze_params(grad_name, param_name, data)
|
|
214
240
|
self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
|
|
215
241
|
|
|
216
|
-
def fill_stack_tensor_data(self):
|
|
217
|
-
self.data_writer.fill_stack_tensor_data()
|
|
218
242
|
|
|
219
243
|
def debug_data_collect_forward(self, variable, name_with_count):
|
|
220
244
|
|
|
221
245
|
data_info = self.data_processor.analyze_debug_forward(variable, name_with_count)
|
|
222
|
-
|
|
246
|
+
name_with_count_category = name_with_count + Const.SEP + Const.DEBUG
|
|
247
|
+
self.data_writer.update_debug({name_with_count_category: data_info})
|
|
223
248
|
|
|
224
249
|
def debug_data_collect_backward(self, variable, grad_name_with_count):
|
|
225
250
|
# prepare all None nested data structure
|
|
226
251
|
all_none_data_info = self.data_processor.analyze_element_to_all_none(variable)
|
|
227
|
-
|
|
252
|
+
grad_name_with_count_category = grad_name_with_count + Const.SEP + Const.DEBUG
|
|
253
|
+
self.data_writer.update_debug({grad_name_with_count_category: all_none_data_info})
|
|
228
254
|
|
|
229
255
|
# register tensor backward hook
|
|
230
|
-
self.data_processor.analyze_debug_backward(variable,
|
|
256
|
+
self.data_processor.analyze_debug_backward(variable, grad_name_with_count_category,
|
|
257
|
+
self.data_writer.cache_debug['data'])
|