mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
- mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +16 -21
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +185 -11
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +4 -0
- msprobe/core/common/utils.py +42 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +381 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +199 -69
- msprobe/core/data_dump/data_collector.py +100 -25
- msprobe/core/data_dump/data_processor/base.py +130 -28
- msprobe/core/data_dump/data_processor/factory.py +8 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
- msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
- msprobe/core/data_dump/json_writer.py +54 -8
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +121 -17
- msprobe/docs/02.config_introduction.md +18 -16
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +107 -58
- msprobe/docs/06.data_dump_MindSpore.md +95 -34
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +310 -220
- msprobe/docs/21.visualization_PyTorch.md +125 -35
- msprobe/docs/22.visualization_MindSpore.md +149 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +525 -0
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +11 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +68 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +28 -2
- msprobe/mindspore/debugger/precision_debugger.py +100 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +13 -3
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +347 -107
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -7
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +116 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +42 -17
- msprobe/pytorch/debugger/precision_debugger.py +56 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +98 -28
- msprobe/pytorch/monitor/csv2tb.py +164 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +543 -318
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +76 -56
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +84 -48
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +264 -115
- msprobe/visualization/builder/graph_builder.py +93 -10
- msprobe/visualization/builder/msprobe_adapter.py +30 -6
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +15 -19
- msprobe/visualization/graph/distributed_analyzer.py +395 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +100 -27
- msprobe/visualization/utils.py +24 -31
- mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
- msprobe/pytorch/functional/module_dump.py +0 -84
- msprobe/pytorch/module_processer.py +0 -150
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -15,17 +15,17 @@
|
|
|
15
15
|
|
|
16
16
|
import functools
|
|
17
17
|
import threading
|
|
18
|
+
from collections import defaultdict
|
|
18
19
|
|
|
19
20
|
import torch
|
|
20
21
|
import torch.nn as nn
|
|
21
22
|
import torch.utils.hooks as full_hooks
|
|
22
23
|
|
|
23
|
-
from msprobe.core.common.const import Const
|
|
24
24
|
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
class HOOKModule(nn.Module):
|
|
28
|
-
module_count =
|
|
28
|
+
module_count = defaultdict(int)
|
|
29
29
|
inner_stop_hook = {}
|
|
30
30
|
|
|
31
31
|
def __init__(self, build_hook) -> None:
|
|
@@ -41,12 +41,7 @@ class HOOKModule(nn.Module):
|
|
|
41
41
|
if hasattr(self, "prefix_op_name_"):
|
|
42
42
|
self.prefix = self.prefix_op_name_
|
|
43
43
|
|
|
44
|
-
|
|
45
|
-
HOOKModule.module_count[self.prefix] = 1
|
|
46
|
-
self.prefix += '0' + Const.SEP
|
|
47
|
-
else:
|
|
48
|
-
HOOKModule.module_count[self.prefix] += 1
|
|
49
|
-
self.prefix = self.prefix + str(HOOKModule.module_count[self.prefix] - 1) + Const.SEP
|
|
44
|
+
self.forward_data_collected = False
|
|
50
45
|
forward_pre_hook, forward_hook, backward_hook, _ = build_hook(self.prefix)
|
|
51
46
|
if torch_version_above_or_equal_2:
|
|
52
47
|
self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
|
|
@@ -66,9 +61,17 @@ class HOOKModule(nn.Module):
|
|
|
66
61
|
HOOKModule.inner_stop_hook[self.current_thread] = False
|
|
67
62
|
return result
|
|
68
63
|
|
|
69
|
-
@
|
|
70
|
-
def reset_module_stats(
|
|
71
|
-
|
|
64
|
+
@staticmethod
|
|
65
|
+
def reset_module_stats():
|
|
66
|
+
HOOKModule.module_count = defaultdict(int)
|
|
67
|
+
|
|
68
|
+
@staticmethod
|
|
69
|
+
def add_module_count(name):
|
|
70
|
+
HOOKModule.module_count[name] += 1
|
|
71
|
+
|
|
72
|
+
@staticmethod
|
|
73
|
+
def get_module_count(name):
|
|
74
|
+
return HOOKModule.module_count[name]
|
|
72
75
|
|
|
73
76
|
def _call_func(self, *args, **kwargs):
|
|
74
77
|
full_backward_hooks, non_full_backward_hooks = [], []
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
from msprobe.core.common.const import Const
|
|
18
|
+
from msprobe.pytorch.common.log import logger
|
|
19
|
+
|
|
20
|
+
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
21
|
+
if torch_version_above_or_equal_2:
|
|
22
|
+
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def register_optimizer_hook(data_collector):
|
|
26
|
+
def optimizer_pre_step_hook(optimizer, args, kwargs):
|
|
27
|
+
data_collector.optimizer_status = Const.OPTIMIZER
|
|
28
|
+
|
|
29
|
+
def optimizer_post_step_hook(optimizer, args, kwargs):
|
|
30
|
+
data_collector.optimizer_status = Const.END_PREFIX + Const.OPTIMIZER
|
|
31
|
+
|
|
32
|
+
def patch_clip_grad(func):
|
|
33
|
+
def wrapper(*args, **kwargs):
|
|
34
|
+
data_collector.optimizer_status = Const.CLIP_GRAD
|
|
35
|
+
func(*args, **kwargs)
|
|
36
|
+
data_collector.optimizer_status = Const.END_PREFIX + Const.CLIP_GRAD
|
|
37
|
+
|
|
38
|
+
return wrapper
|
|
39
|
+
|
|
40
|
+
if torch_version_above_or_equal_2:
|
|
41
|
+
register_optimizer_step_pre_hook(optimizer_pre_step_hook)
|
|
42
|
+
register_optimizer_step_post_hook(optimizer_post_step_hook)
|
|
43
|
+
else:
|
|
44
|
+
logger.info_on_rank_0("Pytorch version is below 2.0, cannot register optimizer hook.")
|
|
45
|
+
|
|
46
|
+
try:
|
|
47
|
+
torch.nn.utils.clip_grad_norm_ = patch_clip_grad(torch.nn.utils.clip_grad_norm_)
|
|
48
|
+
torch.nn.utils.clip_grad_norm = patch_clip_grad(torch.nn.utils.clip_grad_norm)
|
|
49
|
+
torch.nn.utils.clip_grad_value_ = patch_clip_grad(torch.nn.utils.clip_grad_value_)
|
|
50
|
+
except Exception as e:
|
|
51
|
+
logger.info_on_rank_0("Cannot patch clip grad function. detail:%s" % str(e))
|
|
52
|
+
|
|
53
|
+
try:
|
|
54
|
+
from megatron.core.optimizer import MegatronOptimizer
|
|
55
|
+
MegatronOptimizer.clip_grad_norm = patch_clip_grad(MegatronOptimizer.clip_grad_norm)
|
|
56
|
+
except ImportError:
|
|
57
|
+
pass
|
|
58
|
+
except Exception as e:
|
|
59
|
+
logger.info_on_rank_0("Cannot patch megatron clip grad function. detail:%s" % str(e))
|
|
@@ -138,6 +138,10 @@ functional:
|
|
|
138
138
|
- fold
|
|
139
139
|
- multi_head_attention_forward
|
|
140
140
|
- scaled_dot_product_attention
|
|
141
|
+
- lp_pool3d
|
|
142
|
+
- dropout1d
|
|
143
|
+
- mish
|
|
144
|
+
- huber_loss
|
|
141
145
|
|
|
142
146
|
tensor:
|
|
143
147
|
- __add__
|
|
@@ -172,6 +176,7 @@ tensor:
|
|
|
172
176
|
- __sub__
|
|
173
177
|
- __truediv__
|
|
174
178
|
- __xor__
|
|
179
|
+
- __pow__
|
|
175
180
|
- abs
|
|
176
181
|
- abs_
|
|
177
182
|
- absolute
|
|
@@ -557,6 +562,27 @@ tensor:
|
|
|
557
562
|
- view_as
|
|
558
563
|
- xlogy
|
|
559
564
|
- xlogy_
|
|
565
|
+
- split
|
|
566
|
+
- stft
|
|
567
|
+
- nan_to_num
|
|
568
|
+
- dsplit
|
|
569
|
+
- orgqr
|
|
570
|
+
- bitwise_left_shift_
|
|
571
|
+
- arctan2
|
|
572
|
+
- histogram
|
|
573
|
+
- q_zero_point
|
|
574
|
+
- adjoint
|
|
575
|
+
- ormqr
|
|
576
|
+
- bitwise_right_shift_
|
|
577
|
+
- nanquantile
|
|
578
|
+
- lu
|
|
579
|
+
- quantile
|
|
580
|
+
- arctan2_
|
|
581
|
+
- qr
|
|
582
|
+
- diagonal_scatter
|
|
583
|
+
- corrcoef
|
|
584
|
+
- vsplit
|
|
585
|
+
- aminmax
|
|
560
586
|
|
|
561
587
|
torch:
|
|
562
588
|
- linalg.norm
|
|
@@ -1131,6 +1157,14 @@ torch_npu:
|
|
|
1131
1157
|
- npu_lstm
|
|
1132
1158
|
- npu_apply_adam
|
|
1133
1159
|
- npu_apply_adam_w
|
|
1160
|
+
- npu_anti_quant
|
|
1161
|
+
- npu_grouped_matmu
|
|
1162
|
+
- npu_quant_scatter
|
|
1163
|
+
- npu_group_norm_silu
|
|
1164
|
+
- npu_format_cast
|
|
1165
|
+
- npu_moe_finalize_routing
|
|
1166
|
+
- npu_moe_gating_top_k_softmax
|
|
1167
|
+
- npu_trans_quant_param
|
|
1134
1168
|
|
|
1135
1169
|
aten:
|
|
1136
1170
|
- signbit
|
|
@@ -1877,4 +1911,5 @@ distributed:
|
|
|
1877
1911
|
- all_to_all_single
|
|
1878
1912
|
- all_to_all
|
|
1879
1913
|
- all_gather_into_tensor
|
|
1880
|
-
- reduce_scatter_tensor
|
|
1914
|
+
- reduce_scatter_tensor
|
|
1915
|
+
- batch_isend_irecv
|
|
@@ -21,7 +21,6 @@ from msprobe.pytorch.hook_module.hook_module import HOOKModule
|
|
|
21
21
|
from msprobe.pytorch.common.utils import torch_device_guard
|
|
22
22
|
from msprobe.core.common.const import Const
|
|
23
23
|
from msprobe.core.common.file_utils import load_yaml
|
|
24
|
-
from msprobe.core.common.inplace_op_checker import InplaceOpChecker
|
|
25
24
|
|
|
26
25
|
|
|
27
26
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
@@ -49,17 +48,20 @@ class DistributedOPTemplate(HOOKModule):
|
|
|
49
48
|
self.op_name_ = op_name
|
|
50
49
|
self.prefix_op_name_ = "Distributed" + Const.SEP + str(op_name) + Const.SEP
|
|
51
50
|
super().__init__(build_hook)
|
|
52
|
-
if not self.stop_hook
|
|
53
|
-
self.
|
|
51
|
+
if not self.stop_hook:
|
|
52
|
+
self.op_is_distributed = True
|
|
54
53
|
|
|
55
54
|
@torch_device_guard
|
|
56
55
|
def forward(self, *args, **kwargs):
|
|
56
|
+
handle = distributed_func.get(self.op_name_)(*args, **kwargs)
|
|
57
57
|
if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]:
|
|
58
|
-
handle
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
58
|
+
if handle and hasattr(handle, 'wait'):
|
|
59
|
+
handle.wait()
|
|
60
|
+
if self.op_name_ == "batch_isend_irecv":
|
|
61
|
+
if isinstance(handle, list):
|
|
62
|
+
for req in handle:
|
|
63
|
+
req.wait()
|
|
64
|
+
return handle
|
|
63
65
|
|
|
64
66
|
|
|
65
67
|
def wrap_distributed_op(op_name, hook):
|
|
@@ -23,46 +23,6 @@ from msprobe.pytorch.common.log import logger
|
|
|
23
23
|
from msprobe.core.common.file_utils import load_yaml
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def remove_dropout():
|
|
27
|
-
if torch.__version__ > "1.8":
|
|
28
|
-
logger.info_on_rank_0("For precision comparison, the probability p in the dropout method is set to 0.")
|
|
29
|
-
import torch.nn.functional as F
|
|
30
|
-
from torch import _VF
|
|
31
|
-
from torch.overrides import has_torch_function_unary, handle_torch_function
|
|
32
|
-
|
|
33
|
-
def function_dropout(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
34
|
-
inplace: bool = False) -> torch.Tensor:
|
|
35
|
-
if has_torch_function_unary(input_tensor):
|
|
36
|
-
return handle_torch_function(
|
|
37
|
-
function_dropout, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
|
|
38
|
-
if p < 0.0 or p > 1.0:
|
|
39
|
-
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
40
|
-
return _VF.dropout_(input_tensor, 0., training) if inplace else _VF.dropout(input_tensor, 0., training)
|
|
41
|
-
|
|
42
|
-
def function_dropout2d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
43
|
-
inplace: bool = False) -> torch.Tensor:
|
|
44
|
-
if has_torch_function_unary(input_tensor):
|
|
45
|
-
return handle_torch_function(
|
|
46
|
-
function_dropout2d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
|
|
47
|
-
if p < 0.0 or p > 1.0:
|
|
48
|
-
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
49
|
-
return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
|
|
50
|
-
0., training)
|
|
51
|
-
|
|
52
|
-
def function_dropout3d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
53
|
-
inplace: bool = False) -> torch.Tensor:
|
|
54
|
-
if has_torch_function_unary(input_tensor):
|
|
55
|
-
return handle_torch_function(
|
|
56
|
-
function_dropout3d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
|
|
57
|
-
if p < 0.0 or p > 1.0:
|
|
58
|
-
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
59
|
-
return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
|
|
60
|
-
0., training)
|
|
61
|
-
|
|
62
|
-
F.dropout = function_dropout
|
|
63
|
-
F.dropout2d = function_dropout2d
|
|
64
|
-
F.dropout3d = function_dropout3d
|
|
65
|
-
|
|
66
26
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
67
27
|
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
|
|
68
28
|
|
|
@@ -19,7 +19,7 @@ import argparse
|
|
|
19
19
|
import ast
|
|
20
20
|
import heapq
|
|
21
21
|
|
|
22
|
-
from msprobe.
|
|
22
|
+
from msprobe.pytorch.common.log import logger
|
|
23
23
|
from msprobe.core.common.const import MonitorConst
|
|
24
24
|
from msprobe.core.common.file_utils import check_path_before_create, save_json, create_directory, remove_path, \
|
|
25
25
|
check_file_or_directory_path, load_json
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -12,21 +12,22 @@
|
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
|
-
|
|
15
|
+
import itertools
|
|
16
16
|
import os
|
|
17
|
-
import sys
|
|
18
17
|
import statistics as st
|
|
18
|
+
import sys
|
|
19
19
|
from abc import ABC
|
|
20
|
+
from collections import defaultdict
|
|
20
21
|
from dataclasses import dataclass, field
|
|
21
22
|
from typing import List
|
|
22
|
-
from collections import defaultdict
|
|
23
23
|
|
|
24
24
|
import pandas as pd
|
|
25
|
+
import torch
|
|
25
26
|
from torch.utils.tensorboard import SummaryWriter
|
|
26
27
|
|
|
27
|
-
from msprobe.core.common.log import logger
|
|
28
|
-
from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv
|
|
29
28
|
from msprobe.core.common.const import FileCheckConst, MonitorConst
|
|
29
|
+
from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv
|
|
30
|
+
from msprobe.pytorch.common.log import logger
|
|
30
31
|
|
|
31
32
|
|
|
32
33
|
class ScanRule(ABC):
|
|
@@ -134,9 +135,9 @@ class AnomalyDataFactory(ABC):
|
|
|
134
135
|
raise ValueError("tag must be a tuple with length 2")
|
|
135
136
|
tag_name = tag[0]
|
|
136
137
|
param_name = tag_name.split('/')[0]
|
|
137
|
-
call_id = self.name2callid.get(
|
|
138
|
-
if MonitorConst.
|
|
139
|
-
vpp_stage = int(param_name.split(MonitorConst.
|
|
138
|
+
call_id = self.name2callid.get(tag_name, -1)
|
|
139
|
+
if MonitorConst.NAME_SEP in param_name:
|
|
140
|
+
vpp_stage = int(param_name.split(MonitorConst.NAME_SEP)[0])
|
|
140
141
|
else:
|
|
141
142
|
vpp_stage = 0
|
|
142
143
|
|
|
@@ -153,6 +154,24 @@ class AnomalyDataFactory(ABC):
|
|
|
153
154
|
)
|
|
154
155
|
|
|
155
156
|
|
|
157
|
+
class TrainStage:
|
|
158
|
+
DEFAULT_STAGE = -1
|
|
159
|
+
FORWARD_STAGE = 0
|
|
160
|
+
BACKWARD_STAGE = 1
|
|
161
|
+
OPTIMIZER_STAGE = 2
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
FORWARD_KEY = [MonitorConst.ACTV]
|
|
165
|
+
BACKWARD_KEY = [MonitorConst.ACTVGRAD, MonitorConst.PRE_GRAD,
|
|
166
|
+
MonitorConst.POST_GRAD, MonitorConst.ACC_GRAD]
|
|
167
|
+
OPTIMIZER_KEY = [MonitorConst.EXP_AVG, MonitorConst.EXP_AVG_SQ]
|
|
168
|
+
TRAIN_STAGE = {
|
|
169
|
+
**{key_: TrainStage.FORWARD_STAGE for key_ in FORWARD_KEY},
|
|
170
|
+
**{key_: TrainStage.BACKWARD_STAGE for key_ in BACKWARD_KEY},
|
|
171
|
+
**{key_: TrainStage.OPTIMIZER_STAGE for key_ in OPTIMIZER_KEY}
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
|
|
156
175
|
@dataclass(eq=True)
|
|
157
176
|
class GradAnomalyData:
|
|
158
177
|
rank: int = 0
|
|
@@ -166,25 +185,48 @@ class GradAnomalyData:
|
|
|
166
185
|
group_mates: list = field(default=None, compare=False)
|
|
167
186
|
|
|
168
187
|
def __lt__(self, other):
|
|
188
|
+
"""
|
|
189
|
+
自定义比较函数,用于确定 GradAnomalyData 实例之间的顺序。
|
|
190
|
+
比较规则为:
|
|
191
|
+
step 和 micro_step 值越小优先级越高;
|
|
192
|
+
vpp 和 pp 在前向阶段值越小优先级越高,在非前向阶段值越大优先级越高;
|
|
193
|
+
call_id 值越小优先级越高。
|
|
194
|
+
"""
|
|
169
195
|
if not isinstance(other, GradAnomalyData):
|
|
170
196
|
return NotImplemented
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
197
|
+
|
|
198
|
+
self_train_stage = self.get_train_stage(self.tag_name)
|
|
199
|
+
other_train_stage = self.get_train_stage(other.tag_name)
|
|
200
|
+
|
|
201
|
+
def vpp_pp_comparator(anomaly):
|
|
202
|
+
"""
|
|
203
|
+
Determine the priority rule for vpp and pp based on train stage
|
|
204
|
+
Forward stage prefers smaller vpp and pp
|
|
205
|
+
Other stages prefer larger vpp and pp
|
|
206
|
+
"""
|
|
207
|
+
if self_train_stage == TrainStage.FORWARD_STAGE:
|
|
208
|
+
return anomaly.vpp_stage, anomaly.pp_stage
|
|
209
|
+
else:
|
|
210
|
+
return -anomaly.vpp_stage, -anomaly.pp_stage
|
|
211
|
+
|
|
212
|
+
self_cmp = [self.step, self.micro_step, self_train_stage, *vpp_pp_comparator(self), self.call_id]
|
|
213
|
+
other_cmp = [other.step, other.micro_step, other_train_stage, *vpp_pp_comparator(other), other.call_id]
|
|
214
|
+
return self_cmp < other_cmp
|
|
182
215
|
|
|
183
216
|
def __le__(self, other):
|
|
184
217
|
if not isinstance(other, GradAnomalyData):
|
|
185
218
|
return NotImplemented
|
|
186
219
|
return self == other or self < other
|
|
187
220
|
|
|
221
|
+
@staticmethod
|
|
222
|
+
def get_train_stage(tag_name):
|
|
223
|
+
"""
|
|
224
|
+
:param tag_name: "0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq"
|
|
225
|
+
:return: int, if forward return 0; if backward return 1; if optimizer return 2
|
|
226
|
+
"""
|
|
227
|
+
key_ = tag_name.split("/")[-1]
|
|
228
|
+
return TRAIN_STAGE.get(key_, TrainStage.DEFAULT_STAGE)
|
|
229
|
+
|
|
188
230
|
def to_dict(self):
|
|
189
231
|
return self.__dict__
|
|
190
232
|
|
|
@@ -198,7 +240,6 @@ class WriterInput:
|
|
|
198
240
|
path: str
|
|
199
241
|
ad_rules: list
|
|
200
242
|
job_id: str
|
|
201
|
-
anomaly_inform: bool = False
|
|
202
243
|
anomaly_factory: AnomalyDataFactory = None
|
|
203
244
|
ndigits: int = 6
|
|
204
245
|
step_count_per_record: int = 1
|
|
@@ -209,7 +250,6 @@ class BaseWriterWithAD:
|
|
|
209
250
|
self.tag2scalars = {}
|
|
210
251
|
self.ad_rules = writer_input.ad_rules
|
|
211
252
|
self.job_id = writer_input.job_id
|
|
212
|
-
self.anomaly_inform = writer_input.anomaly_inform
|
|
213
253
|
self.anomaly_factory = writer_input.anomaly_factory
|
|
214
254
|
self.anomalies = []
|
|
215
255
|
self.ndigits = writer_input.ndigits
|
|
@@ -242,6 +282,27 @@ class BaseWriterWithAD:
|
|
|
242
282
|
if self.anomaly_factory:
|
|
243
283
|
self.anomalies.append(self.anomaly_factory.create(tag, exception_message, global_step))
|
|
244
284
|
|
|
285
|
+
def write_metrics(self, ops, metric_value, step, prefix=''):
|
|
286
|
+
if not metric_value:
|
|
287
|
+
return
|
|
288
|
+
tensors = []
|
|
289
|
+
tags = list(itertools.product(metric_value.keys(), ops))
|
|
290
|
+
for op2tensor in metric_value.values():
|
|
291
|
+
tensors.extend(op2tensor.values())
|
|
292
|
+
if not tensors:
|
|
293
|
+
return
|
|
294
|
+
|
|
295
|
+
n_slices = len(tensors) // MonitorConst.SLICE_SIZE
|
|
296
|
+
with torch.no_grad():
|
|
297
|
+
for i in range(n_slices + 1):
|
|
298
|
+
begin = i * MonitorConst.SLICE_SIZE
|
|
299
|
+
end = (i+1) * MonitorConst.SLICE_SIZE
|
|
300
|
+
if begin == len(tensors):
|
|
301
|
+
continue
|
|
302
|
+
metric_list = torch.stack(tensors[begin:end]).cpu()
|
|
303
|
+
for tag, metric in zip(tags[begin:end], metric_list):
|
|
304
|
+
self.add_scalar(tag, metric, step)
|
|
305
|
+
|
|
245
306
|
def _ad(self, scalar_value, history):
|
|
246
307
|
return AnomalyScanner.scan(self.ad_rules, history, cur=scalar_value)
|
|
247
308
|
|
|
@@ -291,7 +352,7 @@ class CSVWriterWithAD(BaseWriterWithAD):
|
|
|
291
352
|
"""
|
|
292
353
|
if len(self.context_dict) == 0:
|
|
293
354
|
return
|
|
294
|
-
|
|
355
|
+
|
|
295
356
|
ster_start, step_end = self.get_step_interval(step)
|
|
296
357
|
filepath = os.path.join(self.log_dir, f'{prefix}_{ster_start}-{step_end}.csv')
|
|
297
358
|
if not os.path.exists(filepath):
|
|
@@ -300,11 +361,11 @@ class CSVWriterWithAD(BaseWriterWithAD):
|
|
|
300
361
|
|
|
301
362
|
new_data = []
|
|
302
363
|
for name, metric_value in self.context_dict.items():
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
new_data = pd.DataFrame(new_data).round(self.ndigits)
|
|
364
|
+
new_line = name.split(MonitorConst.NAME_SEP) + metric_value
|
|
365
|
+
new_line.insert(2, step)
|
|
366
|
+
new_data.append(new_line)
|
|
367
|
+
|
|
368
|
+
new_data = pd.DataFrame(new_data).round(self.ndigits).fillna("nan")
|
|
308
369
|
write_df_to_csv(new_data, filepath, mode='a+', header=False)
|
|
309
370
|
self.context_dict = defaultdict(list)
|
|
310
371
|
|
|
@@ -317,6 +378,15 @@ class CSVWriterWithAD(BaseWriterWithAD):
|
|
|
317
378
|
name = tag[0].split('/')[0]
|
|
318
379
|
self.context_dict[name].append(scalar_value.item())
|
|
319
380
|
|
|
381
|
+
def write_metrics(self, ops, metric_value, step, prefix=''):
    """Write metrics, then flush the accumulated context to a CSV file.

    Args:
        ops: list of op/statistic names appended to the CSV header.
        metric_value: nested dict of {metric_name: {op: tensor}}.
        step: current training step.
        prefix: metric category; selects the header layout and the CSV
            file name stem.
    """
    # Fix: forward the caller's prefix instead of hard-coding prefix=''.
    super().write_metrics(ops, metric_value, step, prefix)

    # Activation metrics carry extra name columns (micro-step etc.).
    if prefix in [MonitorConst.ACTV, MonitorConst.ACTVGRAD]:
        self.header = MonitorConst.CSV_HEADER_XY + ops
    else:
        self.header = MonitorConst.CSV_HEADER + ops
    self.write_csv(prefix, step)
|
|
389
|
+
|
|
320
390
|
def close(self):
|
|
321
391
|
pass
|
|
322
392
|
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
import datetime
|
|
16
|
+
import os
|
|
17
|
+
import re
|
|
18
|
+
from multiprocessing import Process
|
|
19
|
+
|
|
20
|
+
import pytz
|
|
21
|
+
from torch.utils.tensorboard import SummaryWriter
|
|
22
|
+
from tqdm import tqdm
|
|
23
|
+
|
|
24
|
+
from msprobe.core.common.const import MonitorConst
|
|
25
|
+
from msprobe.core.common.file_utils import read_csv, create_directory, remove_path
|
|
26
|
+
from msprobe.core.common.utils import is_int
|
|
27
|
+
from msprobe.pytorch.common.log import logger
|
|
28
|
+
from msprobe.pytorch.monitor.utils import get_target_output_dir
|
|
29
|
+
|
|
30
|
+
all_data_type_list = ["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param"]
|
|
31
|
+
CSV_FILE_SUFFIX = r"_\d+-\d+\.csv"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def parse_step_line(line, ops):
    """Extract the tag name, step and per-op values from one CSV row.

    Returns:
        tuple: (vpp_name, step, {op: value}) where vpp_name encodes the
        vpp stage, module name and (when present) micro step.
    """
    vp_id = line["vpp_stage"]
    step = line["step"]
    module_name = line[MonitorConst.HEADER_NAME]
    vpp_name = f"vp{vp_id}:{module_name}"
    # Micro-step granularity is optional; append it only when the column exists.
    if 'micro_step' in line:
        vpp_name = f'{vpp_name}{MonitorConst.NAME_SEP}micro{line["micro_step"]}'
    ops_result = {op: line[op] for op in ops}
    return vpp_name, step, ops_result
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def parse_step_fn(filepath):
    """Parse one monitor CSV file into {vpp_name: {step: {op: value}}}.

    Raises:
        Exception: if the same (vpp_name, step) pair occurs twice in the file.
    """
    frame = read_csv(filepath)
    # Only columns that are known op/statistic names are collected.
    ops = [column for column in frame.keys() if column in MonitorConst.OP_LIST]

    result = {}
    for _, row in frame.iterrows():
        vpp_name, step, ops_result = parse_step_line(row, ops)
        per_name = result.setdefault(vpp_name, {})
        if step in per_name:
            raise Exception(f"duplicated step({step})")
        per_name[step] = ops_result
    return result
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def write_step(output_dirpath, parse_step_result, rank, data_type):
    """Write parsed step results into a TensorBoard event directory.

    Any pre-existing ``rank{rank}/{data_type}`` directory is removed first,
    then one scalar is emitted per (tag, step) pair with steps ascending.

    Args:
        output_dirpath: root output directory.
        parse_step_result: {vpp_name: {step: {op: value}}} as produced by
            ``parse_step_fn``.
        rank: rank id used in the output sub-path.
        data_type: metric category used in the output sub-path.
    """
    tb_output_path = os.path.join(output_dirpath, f"rank{rank}", data_type)
    if os.path.exists(tb_output_path):
        remove_path(tb_output_path)
        logger.warning(f"existing path {tb_output_path} will be recovered")
    writer = SummaryWriter(tb_output_path)
    try:
        for vpp_name, step_data_dict in parse_step_result.items():
            # Ascending step order so TensorBoard renders a clean curve.
            for step, ops in sorted(step_data_dict.items(), key=lambda item: item[0]):
                for op, value in ops.items():
                    writer.add_scalar(f"{vpp_name}/{op}", value, step)
    finally:
        # Fix: the original never closed the writer, so buffered events
        # could be lost when the worker process exits.
        writer.close()
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def update_dict(dict1, dict2):
    """Recursively merge *dict2* into *dict1* in place and return *dict1*.

    Nested dicts are merged key by key; a key present in both sides whose
    values are not both dicts is an error.

    Raises:
        Exception: on a duplicate leaf key, or (wrapped) when a nested
        merge fails.
    """
    for key, incoming in dict2.items():
        if key not in dict1:
            dict1[key] = incoming
            continue
        existing = dict1[key]
        if not (isinstance(existing, dict) and isinstance(incoming, dict)):
            raise Exception(f"duplicate key: {key}")
        try:
            update_dict(existing, incoming)
        except Exception as e:
            # Wrap so the failing path is visible at every nesting level.
            raise Exception(f"Error updating nested dict failed at key '{key}': {e}") from e
    return dict1
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def csv2tb_by_step_work(target_output_dirs, output_dirpath, data_type_list):
    """Worker: convert each rank directory's monitor CSVs to TensorBoard events.

    For every (rank dir, data type) pair, all matching CSV files are parsed,
    merged, and written out. A parse failure aborts the current data type for
    the current rank (remaining files are skipped) but processing continues.
    """
    for entry in tqdm(target_output_dirs):
        rank = entry["rank"]
        dirpath = entry["path"]
        for data_type in data_type_list:
            merged = {}
            pattern = f"{data_type}{CSV_FILE_SUFFIX}"
            for filename in os.listdir(dirpath):
                if not re.match(pattern, filename):
                    continue
                filepath = os.path.join(dirpath, filename)
                try:
                    parsed = parse_step_fn(filepath)
                except Exception as e:
                    logger.error(f"csv2tensorboard parse {filepath} failed \n {e}")
                    break
                merged = update_dict(merged, parsed)
            if merged:
                write_step(output_dirpath, merged, rank, data_type)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def check_process_num(process_num):
    """Validate that *process_num* is a positive integer.

    Raises:
        ValueError: when the value is not an int or is not positive.
    """
    if not (is_int(process_num) and process_num > 0):
        raise ValueError(f"process_num({process_num}) is not a positive integer")
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def check_data_type_list(data_type_list):
    """Validate *data_type_list*; ``None`` means "use every supported type".

    Raises:
        ValueError: if the value is not a list, or contains an unsupported
        data type.
    """
    if data_type_list is None:
        # Fix: corrected the "defualt" typo in the log message.
        logger.info(f"data_type_list is None, use default all_data_type_list: {all_data_type_list}")
        return
    if not isinstance(data_type_list, list):
        raise ValueError(f"data_type_list({data_type_list}) is not a list")
    for data_type in data_type_list:
        if data_type not in all_data_type_list:
            raise ValueError(f"data type({data_type}) is not supported, supported data type: {all_data_type_list}")
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def csv2tensorboard_by_step(
        monitor_path,
        time_start=None,
        time_end=None,
        process_num=1,
        data_type_list=None,
        output_dirpath=None
):
    """Convert monitor CSV output to TensorBoard event files, in parallel.

    Args:
        monitor_path: root directory holding per-rank monitor output.
        time_start: optional lower bound used to select output dirs.
        time_end: optional upper bound used to select output dirs.
        process_num: number of worker processes (positive integer).
        data_type_list: subset of ``all_data_type_list``, or None for all.
        output_dirpath: destination directory; defaults to a timestamped
            sub-directory of *monitor_path*.
    """
    check_process_num(process_num)
    check_data_type_list(data_type_list)
    target_output_dirs = get_target_output_dir(monitor_path, time_start, time_end)
    target_output_dirs = [{"rank": rank, "path": path} for rank, path in target_output_dirs.items()]
    if output_dirpath is None:
        local_tz = pytz.timezone("Asia/Shanghai")  # adjust to the desired target timezone
        cur_time = datetime.datetime.now(local_tz).strftime("%b%d_%H-%M-%S")
        output_dirpath = os.path.join(monitor_path, f"{cur_time}-csv2tensorboard_by_step")
    create_directory(output_dirpath)

    task_num = len(target_output_dirs)
    target_data_type = data_type_list if data_type_list else all_data_type_list

    # Fix: balanced task partitioning. The old `task_num // process_num`
    # scheme dumped the whole remainder on the last process — and when
    # process_num > task_num, every process but the last got an empty slice,
    # serializing all the work. divmod spreads the remainder one task each
    # across the first `extra` processes.
    base, extra = divmod(task_num, process_num)

    processes = []
    task_start_id = 0
    for pro_id in range(process_num):
        chunk = base + (1 if pro_id < extra else 0)
        task_dirs = target_output_dirs[task_start_id: task_start_id + chunk]
        task_start_id += chunk

        p = Process(target=csv2tb_by_step_work, args=(task_dirs, output_dirpath, target_data_type))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
    logger.info(f"output has been saved to: {output_dirpath}")
|