mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -89,12 +89,6 @@ class FuzzHandler(ABC):
|
|
|
89
89
|
)
|
|
90
90
|
return origin_output_chunks, perturbed_output_chunks
|
|
91
91
|
|
|
92
|
-
@staticmethod
|
|
93
|
-
def convert_overflow_ratio_to_consistent(ratio):
|
|
94
|
-
if math.isnan(ratio) or math.isinf(ratio):
|
|
95
|
-
return ThresholdConfig.COMP_CONSISTENT
|
|
96
|
-
return ratio
|
|
97
|
-
|
|
98
92
|
@abstractmethod
|
|
99
93
|
def get_threshold(self, dtype):
|
|
100
94
|
pass
|
|
@@ -107,10 +101,10 @@ class FuzzHandler(ABC):
|
|
|
107
101
|
self, origin_output, perturbed_output, norm_type, abs_tol
|
|
108
102
|
):
|
|
109
103
|
if norm_type == NormType.ENDLESS_NORM:
|
|
110
|
-
return self.
|
|
104
|
+
return self.calculate_max_ratio(origin_output, perturbed_output, abs_tol)
|
|
111
105
|
return ThresholdConfig.COMP_CONSISTENT
|
|
112
106
|
|
|
113
|
-
def
|
|
107
|
+
def calculate_max_ratio(self, origin_output, perturbed_output, abs_tol):
|
|
114
108
|
origin_output_chunks, perturbed_output_chunks = (
|
|
115
109
|
self.tensor_split_for_error_calculate(origin_output, perturbed_output)
|
|
116
110
|
)
|
|
@@ -122,42 +116,30 @@ class FuzzHandler(ABC):
|
|
|
122
116
|
raise FreeBenchmarkException(
|
|
123
117
|
FreeBenchmarkException.OutputIndexError, err_msg
|
|
124
118
|
)
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
norm3 = np.inf
|
|
119
|
+
|
|
120
|
+
max_ratio = ThresholdConfig.COMP_CONSISTENT
|
|
128
121
|
for i, chunk_origin in enumerate(origin_output_chunks):
|
|
129
122
|
if chunk_origin.nelement() == 0:
|
|
130
123
|
break
|
|
131
124
|
chunk_perturbed = perturbed_output_chunks[i]
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
TorchC.
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
)
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
TorchC.div(
|
|
143
|
-
|
|
144
|
-
TorchC.clamp(chunk_origin, min=abs_tol),
|
|
145
|
-
),
|
|
146
|
-
1,
|
|
125
|
+
# 如果乘积最小值 < 极小值乘积的负值,认为存在非极小值符号相反的情况
|
|
126
|
+
if TorchC.lt(
|
|
127
|
+
TorchC.min(TorchC.mul(chunk_origin, chunk_perturbed)), -(abs_tol**2)
|
|
128
|
+
):
|
|
129
|
+
return ThresholdConfig.SYMBOL_FLIPPING
|
|
130
|
+
# 求A/B B/A的比值前,将值限制在大于极小值范围内
|
|
131
|
+
clamp_origin = TorchC.clamp(TorchC.abs(chunk_origin), min=abs_tol)
|
|
132
|
+
clamp_perturbed = TorchC.clamp(TorchC.abs(chunk_perturbed), min=abs_tol)
|
|
133
|
+
# 对于计算结果为nan的情况,认为两者没有差异
|
|
134
|
+
ratio_tensor = TorchC.nan_to_num(
|
|
135
|
+
TorchC.div(clamp_origin, clamp_perturbed),
|
|
136
|
+
nan=ThresholdConfig.COMP_CONSISTENT,
|
|
147
137
|
)
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
|
|
154
|
-
norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(max_ratio1))
|
|
155
|
-
|
|
156
|
-
if norm3 < 0:
|
|
157
|
-
ratio = ThresholdConfig.SYMBOL_FLIPPING
|
|
158
|
-
else:
|
|
159
|
-
ratio = max(norm1, norm2)
|
|
160
|
-
return ratio
|
|
138
|
+
# 求A/B 和 B/A比值最大值,其中 B/A的最大值为 A/B的最小值的倒数
|
|
139
|
+
min_ratio, max_ratio = TorchC.stack([*TorchC.aminmax(ratio_tensor)]).tolist()
|
|
140
|
+
min_ratio_reciprocal = np.inf if min_ratio == 0 else 1 / min_ratio
|
|
141
|
+
max_ratio = max(max_ratio, min_ratio_reciprocal)
|
|
142
|
+
return max_ratio
|
|
161
143
|
|
|
162
144
|
def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float:
|
|
163
145
|
try:
|
|
@@ -220,10 +202,12 @@ class FuzzHandler(ABC):
|
|
|
220
202
|
)
|
|
221
203
|
npu_consistent = is_consistent
|
|
222
204
|
max_fuzz_ratio = (
|
|
223
|
-
max_fuzz_ratio
|
|
205
|
+
max_fuzz_ratio
|
|
206
|
+
if not isinstance(ratio, (int, float))
|
|
207
|
+
else max(max_fuzz_ratio, ratio)
|
|
224
208
|
)
|
|
225
|
-
data_params.is_consistent = is_consistent
|
|
226
|
-
if not is_consistent
|
|
209
|
+
data_params.is_consistent = is_consistent
|
|
210
|
+
if not is_consistent:
|
|
227
211
|
self.unequal_rows.append(
|
|
228
212
|
make_unequal_row(data_params, self.params, ratio=ratio)
|
|
229
213
|
)
|
|
@@ -235,12 +219,12 @@ class FuzzHandler(ABC):
|
|
|
235
219
|
)
|
|
236
220
|
npu_consistent = npu_consistent and is_consistent
|
|
237
221
|
max_fuzz_ratio = (
|
|
238
|
-
max_fuzz_ratio
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
is_consistent and data_params.is_consistent
|
|
222
|
+
max_fuzz_ratio
|
|
223
|
+
if not isinstance(ratio, (int, float))
|
|
224
|
+
else max(max_fuzz_ratio, ratio)
|
|
242
225
|
)
|
|
243
|
-
|
|
226
|
+
data_params.is_consistent = is_consistent
|
|
227
|
+
if not is_consistent:
|
|
244
228
|
self.unequal_rows.append(
|
|
245
229
|
make_unequal_row(
|
|
246
230
|
data_params, self.params, ratio=ratio, index=index_
|
|
@@ -75,10 +75,6 @@ class PreheatHandler(FuzzHandler):
|
|
|
75
75
|
if self.params.preheat_config.get("preheat_step") <= self.params.step:
|
|
76
76
|
return data_params.original_result
|
|
77
77
|
|
|
78
|
-
if not data_params.grad_unequal_flag:
|
|
79
|
-
data_params.grad_unequal_flag = True
|
|
80
|
-
data_params.is_consistent = False
|
|
81
|
-
return data_params.original_result
|
|
82
78
|
preheat_counter.add_api_called_time(self.pure_name)
|
|
83
79
|
|
|
84
80
|
if not self._is_take_a_sample():
|
|
@@ -15,17 +15,17 @@
|
|
|
15
15
|
|
|
16
16
|
import functools
|
|
17
17
|
import threading
|
|
18
|
+
from collections import defaultdict
|
|
18
19
|
|
|
19
20
|
import torch
|
|
20
21
|
import torch.nn as nn
|
|
21
22
|
import torch.utils.hooks as full_hooks
|
|
22
23
|
|
|
23
|
-
from msprobe.core.common.const import Const
|
|
24
24
|
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
class HOOKModule(nn.Module):
|
|
28
|
-
module_count =
|
|
28
|
+
module_count = defaultdict(int)
|
|
29
29
|
inner_stop_hook = {}
|
|
30
30
|
|
|
31
31
|
def __init__(self, build_hook) -> None:
|
|
@@ -41,12 +41,7 @@ class HOOKModule(nn.Module):
|
|
|
41
41
|
if hasattr(self, "prefix_op_name_"):
|
|
42
42
|
self.prefix = self.prefix_op_name_
|
|
43
43
|
|
|
44
|
-
|
|
45
|
-
HOOKModule.module_count[self.prefix] = 1
|
|
46
|
-
self.prefix += '0' + Const.SEP
|
|
47
|
-
else:
|
|
48
|
-
HOOKModule.module_count[self.prefix] += 1
|
|
49
|
-
self.prefix = self.prefix + str(HOOKModule.module_count[self.prefix] - 1) + Const.SEP
|
|
44
|
+
self.forward_data_collected = False
|
|
50
45
|
forward_pre_hook, forward_hook, backward_hook, _ = build_hook(self.prefix)
|
|
51
46
|
if torch_version_above_or_equal_2:
|
|
52
47
|
self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
|
|
@@ -66,9 +61,17 @@ class HOOKModule(nn.Module):
|
|
|
66
61
|
HOOKModule.inner_stop_hook[self.current_thread] = False
|
|
67
62
|
return result
|
|
68
63
|
|
|
69
|
-
@
|
|
70
|
-
def reset_module_stats(
|
|
71
|
-
|
|
64
|
+
@staticmethod
|
|
65
|
+
def reset_module_stats():
|
|
66
|
+
HOOKModule.module_count = defaultdict(int)
|
|
67
|
+
|
|
68
|
+
@staticmethod
|
|
69
|
+
def add_module_count(name):
|
|
70
|
+
HOOKModule.module_count[name] += 1
|
|
71
|
+
|
|
72
|
+
@staticmethod
|
|
73
|
+
def get_module_count(name):
|
|
74
|
+
return HOOKModule.module_count[name]
|
|
72
75
|
|
|
73
76
|
def _call_func(self, *args, **kwargs):
|
|
74
77
|
full_backward_hooks, non_full_backward_hooks = [], []
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
from msprobe.core.common.const import Const
|
|
18
|
+
from msprobe.pytorch.common.log import logger
|
|
19
|
+
|
|
20
|
+
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
21
|
+
if torch_version_above_or_equal_2:
|
|
22
|
+
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def register_optimizer_hook(data_collector):
|
|
26
|
+
def optimizer_pre_step_hook(optimizer, args, kwargs):
|
|
27
|
+
data_collector.optimizer_status = Const.OPTIMIZER
|
|
28
|
+
|
|
29
|
+
def optimizer_post_step_hook(optimizer, args, kwargs):
|
|
30
|
+
data_collector.optimizer_status = Const.END_PREFIX + Const.OPTIMIZER
|
|
31
|
+
|
|
32
|
+
def patch_clip_grad(func):
|
|
33
|
+
def wrapper(*args, **kwargs):
|
|
34
|
+
data_collector.optimizer_status = Const.CLIP_GRAD
|
|
35
|
+
func(*args, **kwargs)
|
|
36
|
+
data_collector.optimizer_status = Const.END_PREFIX + Const.CLIP_GRAD
|
|
37
|
+
|
|
38
|
+
return wrapper
|
|
39
|
+
|
|
40
|
+
if torch_version_above_or_equal_2:
|
|
41
|
+
register_optimizer_step_pre_hook(optimizer_pre_step_hook)
|
|
42
|
+
register_optimizer_step_post_hook(optimizer_post_step_hook)
|
|
43
|
+
else:
|
|
44
|
+
logger.info_on_rank_0("Pytorch version is below 2.0, cannot register optimizer hook.")
|
|
45
|
+
|
|
46
|
+
try:
|
|
47
|
+
torch.nn.utils.clip_grad_norm_ = patch_clip_grad(torch.nn.utils.clip_grad_norm_)
|
|
48
|
+
torch.nn.utils.clip_grad_norm = patch_clip_grad(torch.nn.utils.clip_grad_norm)
|
|
49
|
+
torch.nn.utils.clip_grad_value_ = patch_clip_grad(torch.nn.utils.clip_grad_value_)
|
|
50
|
+
except Exception as e:
|
|
51
|
+
logger.info_on_rank_0("Cannot patch clip grad function. detail:%s" % str(e))
|
|
52
|
+
|
|
53
|
+
try:
|
|
54
|
+
from megatron.core.optimizer import MegatronOptimizer
|
|
55
|
+
MegatronOptimizer.clip_grad_norm = patch_clip_grad(MegatronOptimizer.clip_grad_norm)
|
|
56
|
+
except ImportError:
|
|
57
|
+
pass
|
|
58
|
+
except Exception as e:
|
|
59
|
+
logger.info_on_rank_0("Cannot patch megatron clip grad function. detail:%s" % str(e))
|
|
@@ -138,6 +138,10 @@ functional:
|
|
|
138
138
|
- fold
|
|
139
139
|
- multi_head_attention_forward
|
|
140
140
|
- scaled_dot_product_attention
|
|
141
|
+
- lp_pool3d
|
|
142
|
+
- dropout1d
|
|
143
|
+
- mish
|
|
144
|
+
- huber_loss
|
|
141
145
|
|
|
142
146
|
tensor:
|
|
143
147
|
- __add__
|
|
@@ -172,6 +176,7 @@ tensor:
|
|
|
172
176
|
- __sub__
|
|
173
177
|
- __truediv__
|
|
174
178
|
- __xor__
|
|
179
|
+
- __pow__
|
|
175
180
|
- abs
|
|
176
181
|
- abs_
|
|
177
182
|
- absolute
|
|
@@ -557,6 +562,27 @@ tensor:
|
|
|
557
562
|
- view_as
|
|
558
563
|
- xlogy
|
|
559
564
|
- xlogy_
|
|
565
|
+
- split
|
|
566
|
+
- stft
|
|
567
|
+
- nan_to_num
|
|
568
|
+
- dsplit
|
|
569
|
+
- orgqr
|
|
570
|
+
- bitwise_left_shift_
|
|
571
|
+
- arctan2
|
|
572
|
+
- histogram
|
|
573
|
+
- q_zero_point
|
|
574
|
+
- adjoint
|
|
575
|
+
- ormqr
|
|
576
|
+
- bitwise_right_shift_
|
|
577
|
+
- nanquantile
|
|
578
|
+
- lu
|
|
579
|
+
- quantile
|
|
580
|
+
- arctan2_
|
|
581
|
+
- qr
|
|
582
|
+
- diagonal_scatter
|
|
583
|
+
- corrcoef
|
|
584
|
+
- vsplit
|
|
585
|
+
- aminmax
|
|
560
586
|
|
|
561
587
|
torch:
|
|
562
588
|
- linalg.norm
|
|
@@ -1131,6 +1157,14 @@ torch_npu:
|
|
|
1131
1157
|
- npu_lstm
|
|
1132
1158
|
- npu_apply_adam
|
|
1133
1159
|
- npu_apply_adam_w
|
|
1160
|
+
- npu_anti_quant
|
|
1161
|
+
- npu_grouped_matmu
|
|
1162
|
+
- npu_quant_scatter
|
|
1163
|
+
- npu_group_norm_silu
|
|
1164
|
+
- npu_format_cast
|
|
1165
|
+
- npu_moe_finalize_routing
|
|
1166
|
+
- npu_moe_gating_top_k_softmax
|
|
1167
|
+
- npu_trans_quant_param
|
|
1134
1168
|
|
|
1135
1169
|
aten:
|
|
1136
1170
|
- signbit
|
|
@@ -21,7 +21,6 @@ from msprobe.pytorch.hook_module.hook_module import HOOKModule
|
|
|
21
21
|
from msprobe.pytorch.common.utils import torch_device_guard
|
|
22
22
|
from msprobe.core.common.const import Const
|
|
23
23
|
from msprobe.core.common.file_utils import load_yaml
|
|
24
|
-
from msprobe.core.common.inplace_op_checker import InplaceOpChecker
|
|
25
24
|
|
|
26
25
|
|
|
27
26
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
@@ -49,17 +48,16 @@ class DistributedOPTemplate(HOOKModule):
|
|
|
49
48
|
self.op_name_ = op_name
|
|
50
49
|
self.prefix_op_name_ = "Distributed" + Const.SEP + str(op_name) + Const.SEP
|
|
51
50
|
super().__init__(build_hook)
|
|
52
|
-
if not self.stop_hook
|
|
53
|
-
self.
|
|
51
|
+
if not self.stop_hook:
|
|
52
|
+
self.op_is_distributed = True
|
|
54
53
|
|
|
55
54
|
@torch_device_guard
|
|
56
55
|
def forward(self, *args, **kwargs):
|
|
56
|
+
handle = distributed_func.get(self.op_name_)(*args, **kwargs)
|
|
57
57
|
if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]:
|
|
58
|
-
handle
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
else:
|
|
62
|
-
return distributed_func.get(self.op_name_)(*args, **kwargs)
|
|
58
|
+
if handle and hasattr(handle, 'wait'):
|
|
59
|
+
handle.wait()
|
|
60
|
+
return handle
|
|
63
61
|
|
|
64
62
|
|
|
65
63
|
def wrap_distributed_op(op_name, hook):
|
|
@@ -23,46 +23,6 @@ from msprobe.pytorch.common.log import logger
|
|
|
23
23
|
from msprobe.core.common.file_utils import load_yaml
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def remove_dropout():
|
|
27
|
-
if torch.__version__ > "1.8":
|
|
28
|
-
logger.info_on_rank_0("For precision comparison, the probability p in the dropout method is set to 0.")
|
|
29
|
-
import torch.nn.functional as F
|
|
30
|
-
from torch import _VF
|
|
31
|
-
from torch.overrides import has_torch_function_unary, handle_torch_function
|
|
32
|
-
|
|
33
|
-
def function_dropout(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
34
|
-
inplace: bool = False) -> torch.Tensor:
|
|
35
|
-
if has_torch_function_unary(input_tensor):
|
|
36
|
-
return handle_torch_function(
|
|
37
|
-
function_dropout, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
|
|
38
|
-
if p < 0.0 or p > 1.0:
|
|
39
|
-
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
40
|
-
return _VF.dropout_(input_tensor, 0., training) if inplace else _VF.dropout(input_tensor, 0., training)
|
|
41
|
-
|
|
42
|
-
def function_dropout2d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
43
|
-
inplace: bool = False) -> torch.Tensor:
|
|
44
|
-
if has_torch_function_unary(input_tensor):
|
|
45
|
-
return handle_torch_function(
|
|
46
|
-
function_dropout2d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
|
|
47
|
-
if p < 0.0 or p > 1.0:
|
|
48
|
-
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
49
|
-
return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
|
|
50
|
-
0., training)
|
|
51
|
-
|
|
52
|
-
def function_dropout3d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
53
|
-
inplace: bool = False) -> torch.Tensor:
|
|
54
|
-
if has_torch_function_unary(input_tensor):
|
|
55
|
-
return handle_torch_function(
|
|
56
|
-
function_dropout3d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
|
|
57
|
-
if p < 0.0 or p > 1.0:
|
|
58
|
-
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
59
|
-
return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
|
|
60
|
-
0., training)
|
|
61
|
-
|
|
62
|
-
F.dropout = function_dropout
|
|
63
|
-
F.dropout2d = function_dropout2d
|
|
64
|
-
F.dropout3d = function_dropout3d
|
|
65
|
-
|
|
66
26
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
67
27
|
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
|
|
68
28
|
|
|
@@ -19,7 +19,7 @@ import argparse
|
|
|
19
19
|
import ast
|
|
20
20
|
import heapq
|
|
21
21
|
|
|
22
|
-
from msprobe.
|
|
22
|
+
from msprobe.pytorch.common.log import logger
|
|
23
23
|
from msprobe.core.common.const import MonitorConst
|
|
24
24
|
from msprobe.core.common.file_utils import check_path_before_create, save_json, create_directory, remove_path, \
|
|
25
25
|
check_file_or_directory_path, load_json
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -12,21 +12,22 @@
|
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
|
-
|
|
15
|
+
import itertools
|
|
16
16
|
import os
|
|
17
|
-
import sys
|
|
18
17
|
import statistics as st
|
|
18
|
+
import sys
|
|
19
19
|
from abc import ABC
|
|
20
|
+
from collections import defaultdict
|
|
20
21
|
from dataclasses import dataclass, field
|
|
21
22
|
from typing import List
|
|
22
|
-
from collections import defaultdict
|
|
23
23
|
|
|
24
24
|
import pandas as pd
|
|
25
|
+
import torch
|
|
25
26
|
from torch.utils.tensorboard import SummaryWriter
|
|
26
27
|
|
|
27
|
-
from msprobe.core.common.log import logger
|
|
28
|
-
from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv
|
|
29
28
|
from msprobe.core.common.const import FileCheckConst, MonitorConst
|
|
29
|
+
from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv
|
|
30
|
+
from msprobe.pytorch.common.log import logger
|
|
30
31
|
|
|
31
32
|
|
|
32
33
|
class ScanRule(ABC):
|
|
@@ -134,7 +135,7 @@ class AnomalyDataFactory(ABC):
|
|
|
134
135
|
raise ValueError("tag must be a tuple with length 2")
|
|
135
136
|
tag_name = tag[0]
|
|
136
137
|
param_name = tag_name.split('/')[0]
|
|
137
|
-
call_id = self.name2callid.get(
|
|
138
|
+
call_id = self.name2callid.get(tag_name, -1)
|
|
138
139
|
if MonitorConst.VPP_SEP in param_name:
|
|
139
140
|
vpp_stage = int(param_name.split(MonitorConst.VPP_SEP)[0])
|
|
140
141
|
else:
|
|
@@ -153,6 +154,24 @@ class AnomalyDataFactory(ABC):
|
|
|
153
154
|
)
|
|
154
155
|
|
|
155
156
|
|
|
157
|
+
class TrainStage:
|
|
158
|
+
DEFAULT_STAGE = -1
|
|
159
|
+
FORWARD_STAGE = 0
|
|
160
|
+
BACKWARD_STAGE = 1
|
|
161
|
+
OPTIMIZER_STAGE = 2
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
FORWARD_KEY = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT]
|
|
165
|
+
BACKWARD_KEY = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT,
|
|
166
|
+
MonitorConst.PRE_GRAD, MonitorConst.POST_GRAD, MonitorConst.ACC_GRAD]
|
|
167
|
+
OPTIMIZER_KEY = [MonitorConst.EXP_AVG, MonitorConst.EFXP_AVG_SQ]
|
|
168
|
+
TRAIN_STAGE = {
|
|
169
|
+
**{key_: TrainStage.FORWARD_STAGE for key_ in FORWARD_KEY},
|
|
170
|
+
**{key_: TrainStage.BACKWARD_STAGE for key_ in BACKWARD_KEY},
|
|
171
|
+
**{key_: TrainStage.OPTIMIZER_STAGE for key_ in OPTIMIZER_KEY}
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
|
|
156
175
|
@dataclass(eq=True)
|
|
157
176
|
class GradAnomalyData:
|
|
158
177
|
rank: int = 0
|
|
@@ -166,25 +185,48 @@ class GradAnomalyData:
|
|
|
166
185
|
group_mates: list = field(default=None, compare=False)
|
|
167
186
|
|
|
168
187
|
def __lt__(self, other):
    """
    Decide whether this anomaly should be ordered before ``other``.

    Instances are compared on the key
    (step, micro_step, train_stage, vpp/pp ordering, call_id):
      - smaller step and then smaller micro_step come first;
      - the train stage returned by get_train_stage is compared next;
      - in the forward stage, smaller vpp_stage and pp_stage come first;
        in any other stage, larger vpp_stage and pp_stage come first;
      - finally, a smaller call_id comes first.
    """
    if not isinstance(other, GradAnomalyData):
        return NotImplemented

    stage_of_self = self.get_train_stage(self.tag_name)
    stage_of_other = self.get_train_stage(other.tag_name)
    # The vpp/pp direction is chosen from self's stage only. If the two stages
    # differ, the comparison is settled at or before the stage element, so one
    # shared direction for both sides is enough.
    is_forward = stage_of_self == TrainStage.FORWARD_STAGE

    def pipeline_order(anomaly):
        # Negating flips the order so that larger stages sort first outside forward.
        if is_forward:
            return [anomaly.vpp_stage, anomaly.pp_stage]
        return [-anomaly.vpp_stage, -anomaly.pp_stage]

    lhs = [self.step, self.micro_step, stage_of_self] + pipeline_order(self) + [self.call_id]
    rhs = [other.step, other.micro_step, stage_of_other] + pipeline_order(other) + [other.call_id]
    return lhs < rhs
|
|
182
215
|
|
|
183
216
|
def __le__(self, other):
    """
    Return True if self equals ``other`` or sorts strictly before it.

    Returns NotImplemented for non-GradAnomalyData operands so Python can try
    the reflected operation.
    """
    if isinstance(other, GradAnomalyData):
        return self == other or self < other
    return NotImplemented
|
|
187
220
|
|
|
221
|
+
@staticmethod
def get_train_stage(tag_name):
    """
    Map a monitor tag to its TrainStage code.

    The last "/"-separated segment of the tag is looked up in TRAIN_STAGE.
    Example tags: "0:fc2_0/rank0/input", "0:fc1.weight/rank0/post_grad",
    "0:fc2.weight/rank0/efxp_avg_sq".

    :param tag_name: tag string; its final segment names the monitored quantity
    :return: int, the TRAIN_STAGE value for that segment (forward 0, backward 1,
        optimizer 2), or TrainStage.DEFAULT_STAGE (-1) if it is not mapped
    """
    # rpartition yields the whole string as the tail when there is no "/",
    # matching split("/")[-1].
    _, _, suffix = tag_name.rpartition("/")
    return TRAIN_STAGE.get(suffix, TrainStage.DEFAULT_STAGE)
|
|
229
|
+
|
|
188
230
|
def to_dict(self):
    """
    Return the instance's attribute dictionary.

    This is the live ``__dict__`` object, not a copy: mutating the result
    mutates the instance.
    """
    return vars(self)
|
|
190
232
|
|
|
@@ -198,7 +240,6 @@ class WriterInput:
|
|
|
198
240
|
path: str
|
|
199
241
|
ad_rules: list
|
|
200
242
|
job_id: str
|
|
201
|
-
anomaly_inform: bool = False
|
|
202
243
|
anomaly_factory: AnomalyDataFactory = None
|
|
203
244
|
ndigits: int = 6
|
|
204
245
|
step_count_per_record: int = 1
|
|
@@ -209,7 +250,6 @@ class BaseWriterWithAD:
|
|
|
209
250
|
self.tag2scalars = {}
|
|
210
251
|
self.ad_rules = writer_input.ad_rules
|
|
211
252
|
self.job_id = writer_input.job_id
|
|
212
|
-
self.anomaly_inform = writer_input.anomaly_inform
|
|
213
253
|
self.anomaly_factory = writer_input.anomaly_factory
|
|
214
254
|
self.anomalies = []
|
|
215
255
|
self.ndigits = writer_input.ndigits
|
|
@@ -242,6 +282,27 @@ class BaseWriterWithAD:
|
|
|
242
282
|
if self.anomaly_factory:
|
|
243
283
|
self.anomalies.append(self.anomaly_factory.create(tag, exception_message, global_step))
|
|
244
284
|
|
|
285
|
+
def write_metrics(self, ops, metric_value, step, prefix=''):
    """
    Record every (name, op) metric tensor in ``metric_value`` via add_scalar.

    :param ops: iterable of op (metric) names. It decides which entries of each
        inner dict are written and in what order. Entries whose op is not in
        ``ops`` are skipped.
    :param metric_value: mapping ``name -> {op: tensor}``. The tensors must be
        stackable together with torch.stack (presumably 0-d metric tensors —
        TODO confirm against callers).
    :param step: global step forwarded to add_scalar.
    :param prefix: unused here; accepted so subclasses share the signature.

    Tags are built in the same loop that collects the tensors. An inner dict
    that is empty or lacks some op therefore cannot shift later tensors onto
    the wrong (name, op) tag.
    """
    if not metric_value:
        return
    # Materialize once: ops is iterated once per name below.
    ops = list(ops)

    tags = []
    tensors = []
    for name, op2tensor in metric_value.items():
        for op in ops:
            if op in op2tensor:
                tags.append((name, op))
                tensors.append(op2tensor[op])
    if not tensors:
        return

    # Stack and copy to CPU in chunks of MonitorConst.SLICE_SIZE tensors rather
    # than one host transfer per tensor.
    with torch.no_grad():
        for begin in range(0, len(tensors), MonitorConst.SLICE_SIZE):
            end = begin + MonitorConst.SLICE_SIZE
            metric_list = torch.stack(tensors[begin:end]).cpu()
            for tag, metric in zip(tags[begin:end], metric_list):
                self.add_scalar(tag, metric, step)
|
|
305
|
+
|
|
245
306
|
def _ad(self, scalar_value, history):
    """
    Run anomaly detection for one scalar against its history.

    Delegates to AnomalyScanner.scan with this writer's ad_rules and returns
    its result unchanged.
    """
    rules = self.ad_rules
    return AnomalyScanner.scan(rules, history, cur=scalar_value)
|
|
247
308
|
|
|
@@ -291,7 +352,7 @@ class CSVWriterWithAD(BaseWriterWithAD):
|
|
|
291
352
|
"""
|
|
292
353
|
if len(self.context_dict) == 0:
|
|
293
354
|
return
|
|
294
|
-
|
|
355
|
+
|
|
295
356
|
ster_start, step_end = self.get_step_interval(step)
|
|
296
357
|
filepath = os.path.join(self.log_dir, f'{prefix}_{ster_start}-{step_end}.csv')
|
|
297
358
|
if not os.path.exists(filepath):
|
|
@@ -304,7 +365,7 @@ class CSVWriterWithAD(BaseWriterWithAD):
|
|
|
304
365
|
new_data.append([name] + [step] + metric_value)
|
|
305
366
|
else:
|
|
306
367
|
new_data.append(name.split(MonitorConst.VPP_SEP) + [step] + metric_value)
|
|
307
|
-
new_data = pd.DataFrame(new_data).round(self.ndigits)
|
|
368
|
+
new_data = pd.DataFrame(new_data).round(self.ndigits).fillna("nan")
|
|
308
369
|
write_df_to_csv(new_data, filepath, mode='a+', header=False)
|
|
309
370
|
self.context_dict = defaultdict(list)
|
|
310
371
|
|
|
@@ -317,6 +378,30 @@ class CSVWriterWithAD(BaseWriterWithAD):
|
|
|
317
378
|
name = tag[0].split('/')[0]
|
|
318
379
|
self.context_dict[name].append(scalar_value.item())
|
|
319
380
|
|
|
381
|
+
def write_metrics(self, ops, metric_value, step, prefix=''):
    """
    Buffer the metrics through the base class, then flush them to this prefix's CSV.

    :param ops: iterable of op (metric) names; also used to build the CSV header.
    :param metric_value: mapping ``name -> {op: tensor}`` (see the base class).
    :param step: training step, written as the "step" column.
    :param prefix: file/metric-group prefix passed to write_csv; "actv" and
        "actv_grad" get per-input/output columns, anything else gets one column per op.
    """
    # Materialize once: ops is consumed by the base call and again for the header.
    ops = list(ops)
    super().write_metrics(ops, metric_value, step, prefix=prefix)

    # Build the CSV header for this batch.
    # Activations ("actv") use MonitorConst.ACTV_IN/ACTV_OUT joined with each op.
    # Activation gradients ("actv_grad") use MonitorConst.ACTVGRAD_IN/ACTVGRAD_OUT
    # joined with each op.
    if prefix in {"actv", "actv_grad"}:
        if prefix == "actv":
            input_and_output = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT]
        else:
            input_and_output = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
        ops_ = [MonitorConst.DOT.join(i) for i in itertools.product(input_and_output, ops)]
        csv_header = ["module_name", "step", *ops_]
    else:
        csv_header = ["param_name", "step", *ops]

    # Names containing VPP_SEP are written split into an extra leading vpp_stage column.
    keys = list(metric_value or {})
    if keys and MonitorConst.VPP_SEP in keys[0]:
        csv_header.insert(0, "vpp_stage")

    self.header = csv_header
    try:
        self.write_csv(prefix, step)
    finally:
        # Reset even if writing fails, so a stale header cannot leak into the next call.
        self.header = []
|
|
404
|
+
|
|
320
405
|
def close(self):
    """No-op: this writer keeps no open handle; write_csv writes each batch itself."""
|
|
322
407
|
|