mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/pytorch/monitor/utils.py
CHANGED
|
@@ -22,10 +22,10 @@ import re
|
|
|
22
22
|
|
|
23
23
|
import torch
|
|
24
24
|
|
|
25
|
-
from msprobe.core.common.const import MonitorConst
|
|
25
|
+
from msprobe.core.common.const import MonitorConst
|
|
26
26
|
from msprobe.pytorch.common.log import logger
|
|
27
27
|
from msprobe.core.common.utils import is_int
|
|
28
|
-
from msprobe.core.common.file_utils import check_file_or_directory_path
|
|
28
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
device = "cpu"
|
|
@@ -43,7 +43,6 @@ DIRECTORY_MAX_LENGTH = 4096
|
|
|
43
43
|
|
|
44
44
|
beijing_tz = timezone(timedelta(hours=8))
|
|
45
45
|
MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio"))
|
|
46
|
-
MVGradResult = namedtuple('MVGradResult', ("exp_avg", "exp_avg_sq", "update", "ratio", "grad"))
|
|
47
46
|
|
|
48
47
|
|
|
49
48
|
class MsgConst:
|
|
@@ -102,9 +101,23 @@ def validate_ops(ops):
|
|
|
102
101
|
default_op = MonitorConst.OP_LIST[0]
|
|
103
102
|
valid_ops.append(default_op)
|
|
104
103
|
logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used")
|
|
104
|
+
# 增加默认shape和dtype参数
|
|
105
|
+
if "shape" not in valid_ops:
|
|
106
|
+
valid_ops.append("shape")
|
|
107
|
+
if "dtype" not in valid_ops:
|
|
108
|
+
valid_ops.append("dtype")
|
|
105
109
|
return valid_ops
|
|
106
110
|
|
|
107
111
|
|
|
112
|
+
def validate_ndigits(ndigits):
|
|
113
|
+
if not ndigits:
|
|
114
|
+
return
|
|
115
|
+
if not is_int(ndigits) or ndigits <= 0:
|
|
116
|
+
raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.")
|
|
117
|
+
if ndigits > MonitorConst.MAX_NDIGITS:
|
|
118
|
+
raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.")
|
|
119
|
+
|
|
120
|
+
|
|
108
121
|
def validate_ranks(ranks):
|
|
109
122
|
if not isinstance(ranks, list):
|
|
110
123
|
raise TypeError("module_ranks should be a list")
|
|
@@ -190,7 +203,7 @@ def validate_alert(alert):
|
|
|
190
203
|
args = rule.get("args")
|
|
191
204
|
if args and isinstance(args, dict):
|
|
192
205
|
threshold = args.get("threshold")
|
|
193
|
-
if not isinstance(threshold, float) or threshold < 0:
|
|
206
|
+
if not isinstance(threshold, (float, int)) or threshold < 0:
|
|
194
207
|
raise TypeError('threshold must be float and not less than 0')
|
|
195
208
|
dump = alert.get('dump')
|
|
196
209
|
if dump and not isinstance(dump, bool):
|
|
@@ -206,9 +219,24 @@ def validate_step_count_per_record(step_count_per_record):
|
|
|
206
219
|
raise ValueError("step_count_per_record must smaller than 1e6")
|
|
207
220
|
|
|
208
221
|
|
|
222
|
+
def validate_dynamic_on(dynamic_on):
|
|
223
|
+
if not isinstance(dynamic_on, bool):
|
|
224
|
+
raise TypeError('dynamic_on should be a bool')
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def validate_monitor_mbs_grad(monitor_mbs_grad):
|
|
228
|
+
if not isinstance(monitor_mbs_grad, bool):
|
|
229
|
+
logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.')
|
|
230
|
+
return False
|
|
231
|
+
return monitor_mbs_grad
|
|
232
|
+
|
|
233
|
+
|
|
209
234
|
def validate_config(config):
|
|
210
235
|
config['ops'] = validate_ops(config.get('ops', []))
|
|
211
236
|
|
|
237
|
+
ndigits = config.get('ndigits')
|
|
238
|
+
validate_ndigits(ndigits)
|
|
239
|
+
|
|
212
240
|
eps = config.get('eps', 1e-8)
|
|
213
241
|
if not isinstance(eps, float):
|
|
214
242
|
raise TypeError("eps should be a float")
|
|
@@ -246,9 +274,22 @@ def validate_config(config):
|
|
|
246
274
|
step_count_per_record = config.get('step_count_per_record', 1)
|
|
247
275
|
validate_step_count_per_record(step_count_per_record)
|
|
248
276
|
|
|
277
|
+
config["start_step"] = validate_int_arg(config.get("start_step"), "start_step",
|
|
278
|
+
MonitorConst.DEFAULT_START_STEP, MonitorConst.DEFAULT_START_STEP)
|
|
279
|
+
config["collect_times"] = validate_int_arg(config.get("collect_times"), "collect_times",
|
|
280
|
+
MonitorConst.DEFAULT_MIN_COLLECT_TIMES,
|
|
281
|
+
MonitorConst.DEFAULT_MAX_COLLECT_TIMES)
|
|
282
|
+
config["step_interval"] = validate_int_arg(config.get("step_interval"), "step_interval",
|
|
283
|
+
MonitorConst.DEFAULT_STEP_INTERVAL, MonitorConst.DEFAULT_STEP_INTERVAL)
|
|
284
|
+
|
|
249
285
|
squash_name = config.get('squash_name', True)
|
|
250
286
|
validate_squash_name(squash_name)
|
|
251
287
|
|
|
288
|
+
config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False))
|
|
289
|
+
|
|
290
|
+
dynamic_on = config.get('dynamic_on', False)
|
|
291
|
+
validate_dynamic_on(dynamic_on)
|
|
292
|
+
|
|
252
293
|
if not targets:
|
|
253
294
|
if xy_distribution:
|
|
254
295
|
config["all_xy"] = True
|
|
@@ -257,6 +298,8 @@ def validate_config(config):
|
|
|
257
298
|
|
|
258
299
|
def time_str2time_digit(time_str):
|
|
259
300
|
time_format = '%b%d_%H-%M-%S'
|
|
301
|
+
if not isinstance(time_str, str):
|
|
302
|
+
raise TypeError(f"time_str:{time_str} should be a str")
|
|
260
303
|
try:
|
|
261
304
|
time_digit = datetime.strptime(time_str, time_format)
|
|
262
305
|
except Exception as e:
|
|
@@ -284,3 +327,40 @@ def get_target_output_dir(monitor_path, time_start, time_end):
|
|
|
284
327
|
if start_ok and end_ok:
|
|
285
328
|
result[rank] = os.path.join(monitor_path, dirname)
|
|
286
329
|
return result
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def chmod_tensorboard_dir(path):
|
|
333
|
+
"""
|
|
334
|
+
format配置为tensorboard时,需要补充文件权限设置
|
|
335
|
+
"""
|
|
336
|
+
try:
|
|
337
|
+
recursive_chmod(path)
|
|
338
|
+
except Exception as e:
|
|
339
|
+
logger.warning(f"chmod tensorboard dir wrong because {e}, not updated, please check!!!")
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def validate_set_monitor(grad_acc_steps, start_iteration):
|
|
343
|
+
"""
|
|
344
|
+
validate parameters of set_monitor.
|
|
345
|
+
"""
|
|
346
|
+
grad_acc_steps = validate_int_arg(grad_acc_steps, "grad_acc_steps",
|
|
347
|
+
MonitorConst.DEFAULT_GRAD_ACC_STEPS, MonitorConst.DEFAULT_GRAD_ACC_STEPS)
|
|
348
|
+
|
|
349
|
+
start_iteration = validate_int_arg(start_iteration, "start_iteration",
|
|
350
|
+
MonitorConst.DEFAULT_START_ITERATION, MonitorConst.DEFAULT_START_ITERATION)
|
|
351
|
+
return grad_acc_steps, start_iteration
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def validate_int_arg(value, name, minimum, default_value):
|
|
355
|
+
"""Validate int args, if any exception occurs, use the default value."""
|
|
356
|
+
if value is None:
|
|
357
|
+
return default_value
|
|
358
|
+
try:
|
|
359
|
+
if not is_int(value):
|
|
360
|
+
raise TypeError(f"{name} must be int")
|
|
361
|
+
if value < minimum:
|
|
362
|
+
raise ValueError(f"{name} must greater than {minimum}")
|
|
363
|
+
except Exception as e:
|
|
364
|
+
value = default_value
|
|
365
|
+
logger.warning(f"Validate {name} failed, {e}, replaced with default value {value}.")
|
|
366
|
+
return value
|
|
@@ -125,8 +125,6 @@ class Saver:
|
|
|
125
125
|
|
|
126
126
|
def write_summary_csv(self, test_result):
|
|
127
127
|
test_rows = []
|
|
128
|
-
if self.stack_info:
|
|
129
|
-
test_rows[0].append(self.COLUMN_STACK_INFO)
|
|
130
128
|
|
|
131
129
|
check_op_str_pattern_valid(test_result.api_name)
|
|
132
130
|
df_row = [test_result.api_name, test_result.is_fwd_success, test_result.is_bwd_success]
|
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
import json
|
|
17
17
|
import os
|
|
18
18
|
import time
|
|
19
|
+
import multiprocessing
|
|
19
20
|
from multiprocessing import Pool
|
|
20
21
|
|
|
21
22
|
import torch
|
|
@@ -52,6 +53,7 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
52
53
|
return
|
|
53
54
|
if dump_path is None:
|
|
54
55
|
logger.error("Please set dump_path when dump_mode is config!")
|
|
56
|
+
raise DispatchException("Please set dump_path when dump_mode is config!")
|
|
55
57
|
check_file_or_directory_path(dump_path, True)
|
|
56
58
|
|
|
57
59
|
self.device_id = torch_npu._C._npu_getDevice()
|
|
@@ -85,6 +87,11 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
85
87
|
self.get_ops(yaml_path)
|
|
86
88
|
|
|
87
89
|
self.lock = None
|
|
90
|
+
max_process_num = max(int((multiprocessing.cpu_count() + 1) // Const.CPU_QUARTER), 1)
|
|
91
|
+
if process_num > max_process_num:
|
|
92
|
+
logger.error(f"process_num should be less than or equal to {max_process_num}, but got {process_num}!")
|
|
93
|
+
raise DispatchException(f'process_num should be less than or equal to {max_process_num}, '
|
|
94
|
+
f'but got {process_num}!')
|
|
88
95
|
if process_num > 0:
|
|
89
96
|
self.pool = Pool(process_num)
|
|
90
97
|
if debug:
|
|
@@ -115,6 +122,8 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
115
122
|
if len(json_line_data) == 0:
|
|
116
123
|
break
|
|
117
124
|
msg = json.loads(json_line_data)
|
|
125
|
+
if len(msg) < 2:
|
|
126
|
+
raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.")
|
|
118
127
|
self.all_summary[msg[0]] = msg[1]
|
|
119
128
|
fp_handle.close()
|
|
120
129
|
|
|
@@ -199,8 +208,10 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
199
208
|
dispatch_workflow(run_param, data_info)
|
|
200
209
|
else:
|
|
201
210
|
self.lock.acquire()
|
|
202
|
-
|
|
203
|
-
|
|
211
|
+
try:
|
|
212
|
+
self.all_summary.append([])
|
|
213
|
+
finally:
|
|
214
|
+
self.lock.release()
|
|
204
215
|
run_param.process_flag = True
|
|
205
216
|
if self.check_fun(func, run_param):
|
|
206
217
|
data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, None, npu_out_cpu, cpu_out,
|
|
@@ -19,6 +19,8 @@ import os
|
|
|
19
19
|
from datetime import datetime, timezone
|
|
20
20
|
|
|
21
21
|
import torch
|
|
22
|
+
from msprobe.core.common.const import Const
|
|
23
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
22
24
|
from msprobe.core.common.file_utils import FileOpen, save_npy, save_json
|
|
23
25
|
from msprobe.pytorch.common.log import logger
|
|
24
26
|
|
|
@@ -91,6 +93,7 @@ def support_basic_type(data):
|
|
|
91
93
|
return False
|
|
92
94
|
|
|
93
95
|
|
|
96
|
+
@recursion_depth_decorator("dump_data")
|
|
94
97
|
def dump_data(data, prefix, dump_path):
|
|
95
98
|
if isinstance(data, (tuple, list)) and data:
|
|
96
99
|
for i, item in enumerate(data):
|
|
@@ -107,8 +110,11 @@ def dump_data(data, prefix, dump_path):
|
|
|
107
110
|
def save_temp_summary(api_index, single_api_summary, path, lock):
|
|
108
111
|
summary_path = os.path.join(path, f'summary.json')
|
|
109
112
|
lock.acquire()
|
|
110
|
-
|
|
111
|
-
|
|
113
|
+
try:
|
|
114
|
+
data = [api_index, single_api_summary]
|
|
115
|
+
save_json(summary_path, data, mode='a')
|
|
116
|
+
finally:
|
|
117
|
+
lock.release()
|
|
112
118
|
|
|
113
119
|
|
|
114
120
|
def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo):
|
|
@@ -27,8 +27,10 @@ else:
|
|
|
27
27
|
pta_cpu_device = torch.device("cpu")
|
|
28
28
|
|
|
29
29
|
from msprobe.core.common.const import CompareConst
|
|
30
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
30
31
|
from msprobe.pytorch.common.log import logger
|
|
31
32
|
|
|
33
|
+
|
|
32
34
|
cpu_device = torch._C.device("cpu")
|
|
33
35
|
COLOR_RED = '\033[31m'
|
|
34
36
|
COLOR_GREEN = '\033[32m'
|
|
@@ -85,6 +87,7 @@ def get_callstack():
|
|
|
85
87
|
return callstack
|
|
86
88
|
|
|
87
89
|
|
|
90
|
+
@recursion_depth_decorator("data_to_cpu")
|
|
88
91
|
def data_to_cpu(data, deep, data_cpu):
|
|
89
92
|
global cpu_device
|
|
90
93
|
list_cpu = []
|
|
@@ -45,12 +45,7 @@ class InteractiveCli(cmd.Cmd):
|
|
|
45
45
|
|
|
46
46
|
@catch_exception
|
|
47
47
|
def default(self, line=""):
|
|
48
|
-
self.
|
|
49
|
-
return False
|
|
50
|
-
|
|
51
|
-
@catch_exception
|
|
52
|
-
def do_run(self, line=""):
|
|
53
|
-
self.util.execute_command(line)
|
|
48
|
+
self.stdout.write("Command invalid, Only support command start with cad/vc/dc/pk/cn/pt\n")
|
|
54
49
|
|
|
55
50
|
@catch_exception
|
|
56
51
|
def do_vc(self, line=""):
|
|
@@ -13,12 +13,12 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import hashlib
|
|
17
16
|
import os
|
|
18
17
|
import re
|
|
19
18
|
import subprocess
|
|
20
19
|
import sys
|
|
21
20
|
import time
|
|
21
|
+
import zlib
|
|
22
22
|
from collections import namedtuple
|
|
23
23
|
|
|
24
24
|
import numpy as np
|
|
@@ -114,11 +114,12 @@ class Util:
|
|
|
114
114
|
@staticmethod
|
|
115
115
|
def get_md5_for_numpy(obj):
|
|
116
116
|
np_bytes = obj.tobytes()
|
|
117
|
-
|
|
118
|
-
return
|
|
117
|
+
md5_crc = zlib.crc32(np_bytes)
|
|
118
|
+
return f"{md5_crc:08x}"
|
|
119
119
|
|
|
120
120
|
@staticmethod
|
|
121
121
|
def deal_with_dir_or_file_inconsistency(output_path):
|
|
122
|
+
logger.warning(f"Trying to delete {output_path}")
|
|
122
123
|
remove_path(output_path)
|
|
123
124
|
raise ParseException("Inconsistent directory structure or file.")
|
|
124
125
|
|
|
@@ -264,7 +265,7 @@ class Util:
|
|
|
264
265
|
match = re_pattern.match(name)
|
|
265
266
|
if not match:
|
|
266
267
|
continue
|
|
267
|
-
if extern_pattern != '' and re_pattern.match(extern_pattern) and not
|
|
268
|
+
if extern_pattern != '' and re_pattern.match(extern_pattern) and not name.startswith(extern_pattern):
|
|
268
269
|
continue
|
|
269
270
|
file_list[name] = gen_info_func(name, match, file["root"])
|
|
270
271
|
return file_list
|
msprobe/pytorch/pt_config.py
CHANGED
|
@@ -16,9 +16,9 @@
|
|
|
16
16
|
import os
|
|
17
17
|
import re
|
|
18
18
|
|
|
19
|
-
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.core.common.const import Const, FileCheckConst
|
|
20
20
|
from msprobe.core.common.exceptions import MsprobeException
|
|
21
|
-
from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path,
|
|
21
|
+
from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, FileChecker
|
|
22
22
|
from msprobe.core.common.log import logger
|
|
23
23
|
from msprobe.core.common.utils import is_int
|
|
24
24
|
from msprobe.core.common_config import BaseConfig, CommonConfig
|
|
@@ -42,6 +42,7 @@ class TensorConfig(BaseConfig):
|
|
|
42
42
|
self.tls_path = json_config.get("tls_path", "./")
|
|
43
43
|
self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
|
|
44
44
|
self.check_config()
|
|
45
|
+
self._check_summary_mode()
|
|
45
46
|
self._check_file_format()
|
|
46
47
|
if self.online_run_ut:
|
|
47
48
|
self._check_online_run_ut()
|
|
@@ -65,7 +66,10 @@ class TensorConfig(BaseConfig):
|
|
|
65
66
|
check_file_or_directory_path(self.tls_path, isdir=True)
|
|
66
67
|
check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
|
|
67
68
|
check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
|
|
68
|
-
|
|
69
|
+
check_file_or_directory_path(os.path.join(self.tls_path, "ca.crt"))
|
|
70
|
+
crl_path = os.path.join(self.tls_path, "crl.pem")
|
|
71
|
+
if os.path.exists(crl_path):
|
|
72
|
+
check_file_or_directory_path(crl_path)
|
|
69
73
|
|
|
70
74
|
if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
|
|
71
75
|
raise Exception(f"host: {self.host} is invalid.")
|
|
@@ -80,9 +84,8 @@ class StatisticsConfig(BaseConfig):
|
|
|
80
84
|
self.check_config()
|
|
81
85
|
self._check_summary_mode()
|
|
82
86
|
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
raise Exception("summary_mode is invalid")
|
|
87
|
+
self.tensor_list = json_config.get("tensor_list", [])
|
|
88
|
+
self._check_str_list_config(self.tensor_list, "tensor_list")
|
|
86
89
|
|
|
87
90
|
|
|
88
91
|
class OverflowCheckConfig(BaseConfig):
|
|
@@ -95,6 +98,8 @@ class OverflowCheckConfig(BaseConfig):
|
|
|
95
98
|
def check_overflow_config(self):
|
|
96
99
|
if self.overflow_nums is not None and not is_int(self.overflow_nums):
|
|
97
100
|
raise Exception("overflow_num is invalid")
|
|
101
|
+
if self.overflow_nums is not None and self.overflow_nums != -1 and self.overflow_nums <= 0:
|
|
102
|
+
raise Exception("overflow_nums should be -1 or positive integer")
|
|
98
103
|
if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]:
|
|
99
104
|
raise Exception("check_mode is invalid")
|
|
100
105
|
|
|
@@ -148,7 +153,7 @@ class FreeBenchmarkCheckConfig(BaseConfig):
|
|
|
148
153
|
self.pert_mode in PytorchFreeBenchmarkConst.CPU_MODE_LIST
|
|
149
154
|
):
|
|
150
155
|
msg = (
|
|
151
|
-
f"You
|
|
156
|
+
f"You need to and can only set fuzz_device as {DeviceType.CPU} "
|
|
152
157
|
f"when pert_mode in {PytorchFreeBenchmarkConst.CPU_MODE_LIST}"
|
|
153
158
|
)
|
|
154
159
|
logger.error_log_with_exp(
|
|
@@ -271,13 +276,13 @@ class RunUTConfig(BaseConfig):
|
|
|
271
276
|
|
|
272
277
|
@classmethod
|
|
273
278
|
def check_nfs_path_config(cls, nfs_path):
|
|
274
|
-
if nfs_path
|
|
275
|
-
|
|
279
|
+
if nfs_path:
|
|
280
|
+
FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
|
|
276
281
|
|
|
277
282
|
@classmethod
|
|
278
283
|
def check_tls_path_config(cls, tls_path):
|
|
279
|
-
if tls_path
|
|
280
|
-
|
|
284
|
+
if tls_path:
|
|
285
|
+
FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
|
|
281
286
|
|
|
282
287
|
def check_run_ut_config(self):
|
|
283
288
|
RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.core.common.utils import Const
|
|
17
|
+
from msprobe.core.service import BaseService
|
|
18
|
+
from msprobe.pytorch.attl_manager import ATTLManager
|
|
19
|
+
from msprobe.pytorch.common.log import logger
|
|
20
|
+
from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_above_or_equal_2
|
|
21
|
+
from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
|
|
22
|
+
from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate
|
|
23
|
+
from msprobe.pytorch.hook_module.hook_module import HOOKModule
|
|
24
|
+
from msprobe.pytorch.hook_module.jit_script_wrapper import wrap_jit_script_func
|
|
25
|
+
from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager
|
|
26
|
+
from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
|
|
27
|
+
|
|
28
|
+
if torch_version_above_or_equal_2:
|
|
29
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class PytorchService(BaseService):
|
|
33
|
+
@property
|
|
34
|
+
def _get_framework_type(self):
|
|
35
|
+
return Const.PT_FRAMEWORK
|
|
36
|
+
|
|
37
|
+
@staticmethod
|
|
38
|
+
def _get_current_rank():
|
|
39
|
+
return get_rank_if_initialized()
|
|
40
|
+
|
|
41
|
+
def _init_specific_components(self):
|
|
42
|
+
self.logger = logger
|
|
43
|
+
self.api_register = get_api_register()
|
|
44
|
+
self.module_processor = ModuleProcesser(self.data_collector.scope)
|
|
45
|
+
self.attl_manager = ATTLManager(self.config)
|
|
46
|
+
self.hook_manager = PytorchHookManager(self.data_collector, self.config, self.attl_manager)
|
|
47
|
+
self.api_template = ApiTemplate
|
|
48
|
+
|
|
49
|
+
def _register_hook(self):
|
|
50
|
+
self.attl_manager.attl_init()
|
|
51
|
+
if self._is_mix_level:
|
|
52
|
+
register_optimizer_hook(self.data_collector)
|
|
53
|
+
|
|
54
|
+
def _register_api_hook(self):
|
|
55
|
+
super()._register_api_hook()
|
|
56
|
+
wrap_jit_script_func()
|
|
57
|
+
|
|
58
|
+
def _register_module_hook(self):
|
|
59
|
+
ModuleProcesser.enable_module_dump = True
|
|
60
|
+
self.module_processor.register_module_hook(self.model, self.build_hook)
|
|
61
|
+
self.logger.info(f"The module {self.config.task} hook function is successfully mounted to the model.")
|
|
62
|
+
|
|
63
|
+
def _run_ut_dispatch(self, status):
|
|
64
|
+
if torch_version_above_or_equal_2:
|
|
65
|
+
run_ut_dispatch(self.attl_manager.attl, status, self.config.online_run_ut_recompute)
|
|
66
|
+
|
|
67
|
+
def _reset_status(self):
|
|
68
|
+
super()._reset_status()
|
|
69
|
+
ModuleProcesser.reset_module_stats()
|
|
70
|
+
HOOKModule.reset_module_stats()
|
|
@@ -14,21 +14,23 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import re
|
|
17
|
+
from dataclasses import dataclass
|
|
17
18
|
|
|
18
19
|
from msprobe.core.common.const import Const
|
|
19
|
-
from msprobe.core.common.file_utils import load_json
|
|
20
|
+
from msprobe.core.common.file_utils import load_json, save_json
|
|
21
|
+
from msprobe.core.common.utils import load_stack_json
|
|
20
22
|
from msprobe.visualization.builder.msprobe_adapter import get_input_output
|
|
21
23
|
from msprobe.visualization.builder.msprobe_adapter import op_patterns
|
|
22
24
|
from msprobe.visualization.graph.graph import Graph
|
|
23
25
|
from msprobe.visualization.graph.node_op import NodeOp
|
|
24
|
-
from msprobe.visualization.utils import
|
|
26
|
+
from msprobe.visualization.utils import GraphConst
|
|
25
27
|
|
|
26
28
|
|
|
27
29
|
class GraphBuilder:
|
|
28
30
|
backward_pattern = re.compile(r"(\.backward\.)(\d+)$")
|
|
29
31
|
forward_pattern = re.compile(r"(\.forward\.)(\d+)$")
|
|
30
|
-
# 匹配以大写字母开头,后接任意字母,并以Template(
|
|
31
|
-
template_pattern = re.compile(r'\b[A-Z][a-zA-Z]*Template\(')
|
|
32
|
+
# 匹配以大写字母开头,后接任意字母,并以Template(结尾,或包含api_template(的字符串
|
|
33
|
+
template_pattern = re.compile(r'\b([A-Z][a-zA-Z]*Template|api_template)\(')
|
|
32
34
|
|
|
33
35
|
@staticmethod
|
|
34
36
|
def build(construct_path, data_path, stack_path, model_name='DefaultModel', complete_stack=False):
|
|
@@ -44,13 +46,14 @@ class GraphBuilder:
|
|
|
44
46
|
"""
|
|
45
47
|
construct_dict = load_json(construct_path)
|
|
46
48
|
dump_dict = load_json(data_path)
|
|
47
|
-
stack_dict =
|
|
49
|
+
stack_dict = load_stack_json(stack_path)
|
|
48
50
|
if not complete_stack:
|
|
49
51
|
GraphBuilder._simplify_stack(stack_dict)
|
|
50
52
|
data_dict = dump_dict.get(GraphConst.DATA_KEY, {})
|
|
51
53
|
graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict)
|
|
52
54
|
GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict)
|
|
53
55
|
GraphBuilder._collect_apis_between_modules(graph)
|
|
56
|
+
GraphBuilder._add_parameters_grad(graph, data_dict)
|
|
54
57
|
return graph
|
|
55
58
|
|
|
56
59
|
@staticmethod
|
|
@@ -60,10 +63,10 @@ class GraphBuilder:
|
|
|
60
63
|
"""
|
|
61
64
|
result = {}
|
|
62
65
|
if config.graph_b:
|
|
63
|
-
result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict()
|
|
64
|
-
result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict()
|
|
66
|
+
result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict(config.compare_mode)
|
|
67
|
+
result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict(config.compare_mode)
|
|
65
68
|
else:
|
|
66
|
-
result = config.graph_n.to_dict()
|
|
69
|
+
result = config.graph_n.to_dict(config.compare_mode)
|
|
67
70
|
if config.tool_tip:
|
|
68
71
|
result[GraphConst.JSON_TIP_KEY] = config.tool_tip
|
|
69
72
|
if config.node_colors:
|
|
@@ -73,7 +76,7 @@ class GraphBuilder:
|
|
|
73
76
|
if config.task:
|
|
74
77
|
result[GraphConst.JSON_TASK_KEY] = config.task
|
|
75
78
|
result[GraphConst.OVERFLOW_CHECK] = config.overflow_check
|
|
76
|
-
|
|
79
|
+
save_json(filename, result, indent=4)
|
|
77
80
|
|
|
78
81
|
@staticmethod
|
|
79
82
|
def _simplify_stack(stack_dict):
|
|
@@ -186,6 +189,8 @@ class GraphBuilder:
|
|
|
186
189
|
# 数据格式:"output": [[{param1}, {param2}, ...]]
|
|
187
190
|
if GraphBuilder._is_valid_batch_p2p_output(param_list):
|
|
188
191
|
for param in param_list[0]:
|
|
192
|
+
if not isinstance(param, dict):
|
|
193
|
+
continue
|
|
189
194
|
info = {GraphConst.OP: param.get(GraphConst.OP), GraphConst.PEER: param.get(GraphConst.PEER),
|
|
190
195
|
GraphConst.GROUP_ID: param.get(GraphConst.GROUP_ID)}
|
|
191
196
|
node.batch_p2p_info.append(info)
|
|
@@ -235,10 +240,46 @@ class GraphBuilder:
|
|
|
235
240
|
|
|
236
241
|
graph.root.subnodes = output
|
|
237
242
|
|
|
243
|
+
@staticmethod
|
|
244
|
+
def _add_parameters_grad(graph, data_dict):
|
|
245
|
+
"""
|
|
246
|
+
将parameters_grad信息添加到graph中,
|
|
247
|
+
对应模块的parameters_grad节点添加到对应模块的最后一次backward节点(backward计数最大)内作为子节点
|
|
248
|
+
|
|
249
|
+
例如,graph有节点Module.a.backward.0, Module.a.backward.1, Module.a.backward.2
|
|
250
|
+
则Module.a.parameters_grad添加在Module.a.backward.2内作为子节点
|
|
251
|
+
"""
|
|
252
|
+
prefixes = []
|
|
253
|
+
suffix = Const.SEP + Const.PARAMS_GRAD
|
|
254
|
+
for node_id in data_dict.keys():
|
|
255
|
+
if node_id not in graph.node_map and node_id.endswith(suffix):
|
|
256
|
+
prefixes.append(node_id.replace(suffix, ''))
|
|
257
|
+
|
|
258
|
+
max_info = {prefix: 0 for prefix in prefixes}
|
|
259
|
+
|
|
260
|
+
for key in graph.node_map.keys():
|
|
261
|
+
parts = key.split(Const.SEP)
|
|
262
|
+
if len(parts) > 2 and parts[-2] == Const.BACKWARD:
|
|
263
|
+
num = int(parts[-1])
|
|
264
|
+
prefix = Const.SEP.join(parts[:-2])
|
|
265
|
+
if prefix in max_info and num > max_info[prefix]:
|
|
266
|
+
max_info[prefix] = num
|
|
267
|
+
|
|
268
|
+
for prefix, num in max_info.items():
|
|
269
|
+
node_id = prefix + Const.SEP + Const.BACKWARD + Const.SEP + str(num)
|
|
270
|
+
node = graph.get_node(node_id)
|
|
271
|
+
if node:
|
|
272
|
+
parameters_grad_node_id = graph.add_node(NodeOp.module, prefix + suffix, up_node=node)
|
|
273
|
+
# 添加输入输出数据
|
|
274
|
+
node_data = data_dict.get(parameters_grad_node_id, {})
|
|
275
|
+
input_data, output_data = get_input_output(node_data, parameters_grad_node_id)
|
|
276
|
+
# 更新数据
|
|
277
|
+
graph.get_node(parameters_grad_node_id).set_input_output(input_data, output_data)
|
|
278
|
+
|
|
238
279
|
|
|
239
280
|
class GraphExportConfig:
|
|
240
281
|
def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='',
|
|
241
|
-
overflow_check=False):
|
|
282
|
+
overflow_check=False, compare_mode=None):
|
|
242
283
|
self.graph_n = graph_n
|
|
243
284
|
self.graph_b = graph_b
|
|
244
285
|
self.tool_tip = tool_tip
|
|
@@ -246,3 +287,21 @@ class GraphExportConfig:
|
|
|
246
287
|
self.micro_steps = micro_steps
|
|
247
288
|
self.task = task
|
|
248
289
|
self.overflow_check = overflow_check
|
|
290
|
+
self.compare_mode = compare_mode
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
@dataclass
|
|
294
|
+
class GraphInfo:
|
|
295
|
+
graph: Graph
|
|
296
|
+
construct_path: str
|
|
297
|
+
data_path: str
|
|
298
|
+
stack_path: str
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
@dataclass
|
|
302
|
+
class BuildGraphTaskInfo:
|
|
303
|
+
graph_info_n: GraphInfo
|
|
304
|
+
graph_info_b: GraphInfo
|
|
305
|
+
npu_rank: str
|
|
306
|
+
bench_rank: str
|
|
307
|
+
time_str: str
|